Files
hermes-agent/tools/batch_executor.py
Alexander Whitestone 9f0c410481
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Successful in 35s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 1m48s
Tests / test (pull_request) Failing after 36m13s
feat: batch tool execution with parallel safety checks (#749)
Centralized safety classification for tool call batches:

tools/batch_executor.py (new):
- classify_tool_calls() — classifies batch into parallel_safe,
  path_scoped, sequential, never_parallel tiers
- BatchExecutionPlan — structured plan with parallel and sequential batches
- Path conflict detection — write_file + patch on same file go sequential
- Destructive command detection — rm, mv, sed -i, redirects
- execute_parallel_batch() — ThreadPoolExecutor for concurrent execution

tools/registry.py (enhanced):
- ToolEntry.parallel_safe field — tools can declare parallel safety
- registry.register() accepts parallel_safe=True parameter
- registry.get_parallel_safe_tools() — query registry-declared safe tools

Safety tiers:
- parallel_safe: read_file, web_search, search_files, etc.
- path_scoped: write_file, patch (concurrent when paths don't overlap)
- sequential: terminal, delegate_task, unknown tools
- never_parallel: clarify (requires user interaction)

19 tests passing.
2026-04-15 22:17:16 -04:00

295 lines
10 KiB
Python

"""Batch Tool Executor — Parallel safety classification and concurrent execution.
Provides centralized classification of tool calls into parallel-safe vs sequential,
and utilities for batch execution with safety checks.
Classification tiers:
- PARALLEL_SAFE: read-only tools, no shared state (web_search, read_file, etc.)
- PATH_SCOPED: file operations that can run concurrently when paths don't overlap
- SEQUENTIAL: writes, destructive ops, terminal commands, delegation
- NEVER_PARALLEL: clarify (requires user interaction)
Usage:
from tools.batch_executor import classify_tool_calls, BatchExecutionPlan
plan = classify_tool_calls(tool_calls)
if plan.can_parallelize:
execute_concurrent(plan.parallel_batch)
execute_sequential(plan.sequential_batch)
"""
import json
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
logger = logging.getLogger(__name__)
# ── Safety Classification ──────────────────────────────────────────────────

# Tools that can ALWAYS run in parallel (read-only, no shared state).
# NOTE(review): "fact_store" reads like a write operation — confirm it is
# actually safe to run concurrently before relying on this entry.
DEFAULT_PARALLEL_SAFE = frozenset({
    "ha_get_state",
    "ha_list_entities",
    "ha_list_services",
    "read_file",
    "search_files",
    "session_search",
    "skill_view",
    "skills_list",
    "vision_analyze",
    "web_extract",
    "web_search",
    "fact_store",
    "fact_search",
})

# File tools that can run concurrently ONLY when paths don't overlap.
# read_file also appears in DEFAULT_PARALLEL_SAFE, which classification
# checks first, so reads are never path-serialized.
PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"})

# Tools that must NEVER run in parallel (require user interaction, shared mutable state).
NEVER_PARALLEL = frozenset({"clarify"})
# Patterns that indicate terminal commands may modify/delete files.
# A destructive verb must appear at the start of the string or right after a
# command separator (whitespace, &&, ||, ;, or a backtick). Word boundaries
# (\b) instead of \s let verbs match at end-of-string too (e.g. a bare
# "git reset"), which the previous trailing-\s form missed.
DESTRUCTIVE_PATTERNS = re.compile(
    r"""(?:^|\s|&&|\|\||;|`)(?:
    rm\b|rmdir\b|
    mv\b|
    sed\s+-i|
    truncate\b|
    dd\b|
    shred\b|
    git\s+(?:reset|clean|checkout)\b
    )""",
    re.VERBOSE,
)
# A single ">" (overwrite redirect) that is not part of ">>" (append).
# Lookarounds also match a ">" at the very start or end of the command,
# which the previous character-class form ([^>]>[^>]|^>[^>]) required a
# following character for and therefore missed.
REDIRECT_OVERWRITE = re.compile(r"(?<!>)>(?!>)")


def is_destructive_command(cmd: str) -> bool:
    """Return True if *cmd* looks like it may modify or delete files.

    Conservative by design: a false positive merely forces sequential
    execution, while a false negative could let concurrent calls clobber
    each other's files.

    Args:
        cmd: Raw shell command string (may be empty or None-ish falsy).
    """
    if not cmd:
        return False
    if DESTRUCTIVE_PATTERNS.search(cmd):
        return True
    return bool(REDIRECT_OVERWRITE.search(cmd))
def _paths_overlap(path1: Path, path2: Path) -> bool:
"""Check if two paths could conflict (one is ancestor of the other)."""
try:
path1 = path1.resolve()
path2 = path2.resolve()
return path1 == path2 or path1 in path2.parents or path2 in path1.parents
except Exception:
return True # conservative: assume overlap
def _extract_path(tool_name: str, args: dict) -> Optional[Path]:
    """Pull the target path from *args* for a path-scoped tool.

    Returns None when the tool is not path-scoped, when no usable string
    "path" argument is present, or when path resolution fails.
    """
    if tool_name not in PATH_SCOPED_TOOLS:
        return None
    candidate = args.get("path")
    if not (isinstance(candidate, str) and candidate.strip()):
        return None
    try:
        return Path(candidate).expanduser().resolve()
    except Exception:
        return None
# ── Classification ─────────────────────────────────────────────────────────
@dataclass
class ToolCallClassification:
    """Classification result for a single tool call.

    Produced per tool call and grouped into parallel/sequential batches
    when building an execution plan.
    """
    tool_name: str  # name of the tool being invoked
    args: dict  # parsed JSON arguments ({} when parsing failed)
    tool_call: Any  # the original tool_call object
    tier: str  # "parallel_safe", "path_scoped", "sequential", "never_parallel"
    reason: str = ""  # human-readable explanation for the chosen tier
@dataclass
class BatchExecutionPlan:
    """Plan for executing a batch of tool calls.

    Splits calls into a batch safe to run concurrently and a batch that
    must run one at a time; `classifications` keeps every call in its
    original order regardless of which batch it landed in.
    """
    classifications: List[ToolCallClassification] = field(default_factory=list)
    parallel_batch: List[ToolCallClassification] = field(default_factory=list)
    sequential_batch: List[ToolCallClassification] = field(default_factory=list)

    @property
    def can_parallelize(self) -> bool:
        # Concurrency only pays off with at least two overlappable calls.
        return len(self.parallel_batch) > 1

    @property
    def total(self) -> int:
        # Total number of classified tool calls in this plan.
        return len(self.classifications)
def classify_single_tool_call(
    tool_call: Any,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> ToolCallClassification:
    """Classify a single tool call into its safety tier.

    Args:
        tool_call: OpenAI-style tool call exposing ``.function.name`` and
            ``.function.arguments`` (a JSON string).
        extra_parallel_safe: Additional tool names to treat as
            parallel-safe (e.g. declared via the registry). Previously
            annotated as ``Set[str] = None``, which mistyped the default.

    Returns:
        A ToolCallClassification whose ``tier`` is one of "parallel_safe",
        "path_scoped", "sequential", or "never_parallel".
    """
    tool_name = tool_call.function.name
    try:
        args = json.loads(tool_call.function.arguments)
    except Exception:
        # Unparseable arguments: run sequentially so the normal tool
        # invocation path can surface the error on its own.
        return ToolCallClassification(
            tool_name=tool_name, args={}, tool_call=tool_call,
            tier="sequential", reason="Could not parse arguments"
        )
    if not isinstance(args, dict):
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Non-dict arguments"
        )
    # Never-parallel tools (user interaction) take precedence over all else.
    if tool_name in NEVER_PARALLEL:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="never_parallel", reason="Requires user interaction"
        )
    # Check parallel-safe FIRST (before path_scoped) so read_file/search_files
    # get classified as parallel_safe even though they have paths.
    parallel_safe_set = DEFAULT_PARALLEL_SAFE
    if extra_parallel_safe:
        parallel_safe_set = parallel_safe_set | extra_parallel_safe
    if tool_name in parallel_safe_set:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="parallel_safe", reason="Read-only, no shared state"
        )
    # Terminal commands always run sequentially; destructive ones get a
    # more specific reason for logging/debugging.
    if tool_name == "terminal":
        cmd = args.get("command", "")
        if is_destructive_command(cmd):
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="sequential", reason=f"Destructive command: {cmd[:50]}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Terminal command (conservative)"
        )
    # Path-scoped tools (write_file, patch — read_file was handled above).
    if tool_name in PATH_SCOPED_TOOLS:
        path = _extract_path(tool_name, args)
        if path:
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="path_scoped", reason=f"Path: {path}"
            )
        # Without a usable path we cannot prove non-overlap: serialize.
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Path-scoped but no path found"
        )
    # Default: sequential (conservative for unknown tools).
    return ToolCallClassification(
        tool_name=tool_name, args=args, tool_call=tool_call,
        tier="sequential", reason="Not classified as parallel-safe"
    )
def classify_tool_calls(
    tool_calls: list,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> BatchExecutionPlan:
    """Classify a batch of tool calls and produce an execution plan.

    Path-scoped calls are admitted to the parallel batch greedily: each
    one reserves its resolved path, and any later call whose path overlaps
    a reserved path is demoted to the sequential batch.

    Args:
        tool_calls: OpenAI-style tool calls (see classify_single_tool_call).
        extra_parallel_safe: Additional tool names to treat as
            parallel-safe. Previously annotated as ``Set[str] = None``,
            which mistyped the default.

    Returns:
        A BatchExecutionPlan with calls split into parallel_batch and
        sequential_batch; `classifications` keeps the original order.
    """
    plan = BatchExecutionPlan()
    reserved_paths: List[Path] = []
    for tc in tool_calls:
        classification = classify_single_tool_call(tc, extra_parallel_safe)
        plan.classifications.append(classification)
        if classification.tier in ("never_parallel", "sequential"):
            plan.sequential_batch.append(classification)
            continue
        if classification.tier == "path_scoped":
            path = _extract_path(classification.tool_name, classification.args)
            if path is None:
                # Should not normally happen (the tier implies a path),
                # but stay conservative if extraction fails on re-check.
                classification.tier = "sequential"
                classification.reason = "Path extraction failed"
                plan.sequential_batch.append(classification)
                continue
            # Check for path conflicts with already-scheduled parallel calls.
            if any(_paths_overlap(path, existing) for existing in reserved_paths):
                classification.tier = "sequential"
                classification.reason = f"Path conflict: {path}"
                plan.sequential_batch.append(classification)
            else:
                reserved_paths.append(path)
                plan.parallel_batch.append(classification)
            continue
        if classification.tier == "parallel_safe":
            plan.parallel_batch.append(classification)
            continue
        # Fallback for any unrecognized tier value: serialize.
        plan.sequential_batch.append(classification)
    return plan
# ── Concurrent Execution ───────────────────────────────────────────────────
def execute_parallel_batch(
    batch: List[ToolCallClassification],
    invoke_fn: Callable,
    max_workers: int = 8,
) -> List[Tuple[str, str]]:
    """Execute parallel-safe tool calls concurrently.

    Args:
        batch: List of classified tool calls (parallel_safe or path_scoped).
        invoke_fn: Function(tool_name, args) -> result_string.
        max_workers: Max concurrent threads.

    Returns:
        List of (tool_call_id, result_string) tuples, in completion
        order — not necessarily the order of *batch*. Exceptions raised
        by invoke_fn become a JSON {"error": ...} result string.
    """
    # Guard the empty batch: ThreadPoolExecutor(max_workers=0) raises
    # ValueError, so min(max_workers, len(batch)) would crash here.
    if not batch:
        return []
    results = []
    with ThreadPoolExecutor(max_workers=min(max_workers, len(batch))) as executor:
        future_to_tc = {
            executor.submit(invoke_fn, tc.tool_name, tc.args): tc
            for tc in batch
        }
        for future in as_completed(future_to_tc):
            tc = future_to_tc[future]
            try:
                result = future.result()
            except Exception as e:
                # Surface the failure to the caller as a result payload
                # rather than aborting the rest of the batch.
                result = json.dumps({"error": str(e)})
            tool_call_id = getattr(tc.tool_call, "id", None) or ""
            results.append((tool_call_id, result))
    return results