Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Successful in 35s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 1m48s
Tests / test (pull_request) Failing after 36m13s
Centralized safety classification for tool call batches

tools/batch_executor.py (new):
- classify_tool_calls() — classifies batch into parallel_safe, path_scoped, sequential, never_parallel tiers
- BatchExecutionPlan — structured plan with parallel and sequential batches
- Path conflict detection — write_file + patch on same file go sequential
- Destructive command detection — rm, mv, sed -i, redirects
- execute_parallel_batch() — ThreadPoolExecutor for concurrent execution

tools/registry.py (enhanced):
- ToolEntry.parallel_safe field — tools can declare parallel safety
- registry.register() accepts parallel_safe=True parameter
- registry.get_parallel_safe_tools() — query registry-declared safe tools

Safety tiers:
- parallel_safe: read_file, web_search, search_files, etc.
- path_scoped: write_file, patch (concurrent when paths don't overlap)
- sequential: terminal, delegate_task, unknown tools
- never_parallel: clarify (requires user interaction)

19 tests passing.
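A minimal sketch of how the registry-side flag could feed classification. register() and get_parallel_safe_tools() are named in the message above, but their exact signatures, the module-level registry singleton, and web_search_fn are assumptions here:

from tools.batch_executor import classify_tool_calls
from tools.registry import registry  # assumed module-level singleton

def web_search_fn(query: str) -> str:  # hypothetical tool implementation
    ...

# Tools opt in at registration time (parallel_safe flag per the message above).
registry.register("web_search", web_search_fn, parallel_safe=True)

# Registry-declared safe tools extend DEFAULT_PARALLEL_SAFE inside the classifier,
# assuming get_parallel_safe_tools() returns a set of tool names.
# tool_calls is the model's tool-call batch.
plan = classify_tool_calls(
    tool_calls,
    extra_parallel_safe=registry.get_parallel_safe_tools(),
)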
295 lines
10 KiB
Python
"""Batch Tool Executor — Parallel safety classification and concurrent execution.
|
|
|
|
Provides centralized classification of tool calls into parallel-safe vs sequential,
|
|
and utilities for batch execution with safety checks.
|
|
|
|
Classification tiers:
|
|
- PARALLEL_SAFE: read-only tools, no shared state (web_search, read_file, etc.)
|
|
- PATH_SCOPED: file operations that can run concurrently when paths don't overlap
|
|
- SEQUENTIAL: writes, destructive ops, terminal commands, delegation
|
|
- NEVER_PARALLEL: clarify (requires user interaction)
|
|
|
|
Usage:
|
|
from tools.batch_executor import classify_tool_calls, BatchExecutionPlan
|
|
|
|
plan = classify_tool_calls(tool_calls)
|
|
if plan.can_parallelize:
|
|
execute_concurrent(plan.parallel_batch)
|
|
execute_sequential(plan.sequential_batch)
|
|
"""

import json
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

logger = logging.getLogger(__name__)


# ── Safety Classification ──────────────────────────────────────────────────

# Tools that can ALWAYS run in parallel (read-only, no shared state)
DEFAULT_PARALLEL_SAFE = frozenset({
    "ha_get_state",
    "ha_list_entities",
    "ha_list_services",
    "read_file",
    "search_files",
    "session_search",
    "skill_view",
    "skills_list",
    "vision_analyze",
    "web_extract",
    "web_search",
    "fact_store",
    "fact_search",
})

# File tools that can run concurrently ONLY when paths don't overlap
PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"})

# Tools that must NEVER run in parallel (require user interaction, shared mutable state)
NEVER_PARALLEL = frozenset({"clarify"})

# Patterns that indicate terminal commands may modify/delete files
DESTRUCTIVE_PATTERNS = re.compile(
    r"""(?:^|\s|&&|\|\||;|`)(?:
    rm\s|rmdir\s|
    mv\s|
    sed\s+-i|
    truncate\s|
    dd\s|
    shred\s|
    git\s+(?:reset|clean|checkout)\s
    )""",
    re.VERBOSE,
)

# Output redirects that overwrite files (> but not >>)
REDIRECT_OVERWRITE = re.compile(r'[^>]>[^>]|^>[^>]')


def is_destructive_command(cmd: str) -> bool:
    """Check if a terminal command modifies/deletes files."""
    if not cmd:
        return False
    if DESTRUCTIVE_PATTERNS.search(cmd):
        return True
    if REDIRECT_OVERWRITE.search(cmd):
        return True
    return False
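
# Examples (illustrative):
#   is_destructive_command("rm -rf build")        -> True   (destructive pattern)
#   is_destructive_command("echo hi > out.txt")   -> True   (overwrite redirect)
#   is_destructive_command("echo hi >> log.txt")  -> False  (append not flagged)
#   is_destructive_command("cat out.txt")         -> False  (read-only)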


def _paths_overlap(path1: Path, path2: Path) -> bool:
    """Check if two paths could conflict (one is ancestor of the other)."""
    try:
        path1 = path1.resolve()
        path2 = path2.resolve()
        return path1 == path2 or path1 in path2.parents or path2 in path1.parents
    except Exception:
        return True  # conservative: assume overlap
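
# Examples (illustrative):
#   _paths_overlap(Path("/tmp/a"), Path("/tmp/a"))    -> True   (same path)
#   _paths_overlap(Path("/tmp/a"), Path("/tmp/a/b"))  -> True   (ancestor/descendant)
#   _paths_overlap(Path("/tmp/a"), Path("/tmp/b"))    -> False  (disjoint)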


def _extract_path(tool_name: str, args: dict) -> Optional[Path]:
    """Extract the target path from tool arguments for path-scoped tools."""
    if tool_name not in PATH_SCOPED_TOOLS:
        return None
    raw_path = args.get("path")
    if not isinstance(raw_path, str) or not raw_path.strip():
        return None
    try:
        return Path(raw_path).expanduser().resolve()
    except Exception:
        return None


# ── Classification ─────────────────────────────────────────────────────────

@dataclass
class ToolCallClassification:
    """Classification result for a single tool call."""
    tool_name: str
    args: dict
    tool_call: Any  # the original tool_call object
    tier: str  # "parallel_safe", "path_scoped", "sequential", "never_parallel"
    reason: str = ""


@dataclass
class BatchExecutionPlan:
    """Plan for executing a batch of tool calls."""
    classifications: List[ToolCallClassification] = field(default_factory=list)
    parallel_batch: List[ToolCallClassification] = field(default_factory=list)
    sequential_batch: List[ToolCallClassification] = field(default_factory=list)

    @property
    def can_parallelize(self) -> bool:
        return len(self.parallel_batch) > 1

    @property
    def total(self) -> int:
        return len(self.classifications)


def classify_single_tool_call(
    tool_call: Any,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> ToolCallClassification:
    """Classify a single tool call into its safety tier."""
    tool_name = tool_call.function.name
    try:
        args = json.loads(tool_call.function.arguments)
    except Exception:
        return ToolCallClassification(
            tool_name=tool_name, args={}, tool_call=tool_call,
            tier="sequential", reason="Could not parse arguments"
        )

    if not isinstance(args, dict):
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Non-dict arguments"
        )

    # Check never-parallel
    if tool_name in NEVER_PARALLEL:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="never_parallel", reason="Requires user interaction"
        )

    # Check parallel-safe FIRST (before path_scoped) so read_file/search_files
    # get classified as parallel_safe even though they have paths
    parallel_safe_set = DEFAULT_PARALLEL_SAFE
    if extra_parallel_safe:
        parallel_safe_set = parallel_safe_set | extra_parallel_safe

    if tool_name in parallel_safe_set:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="parallel_safe", reason="Read-only, no shared state"
        )

    # Check terminal commands for destructive operations
    if tool_name == "terminal":
        cmd = args.get("command", "")
        if is_destructive_command(cmd):
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="sequential", reason=f"Destructive command: {cmd[:50]}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Terminal command (conservative)"
        )

    # Check path-scoped tools (write_file, patch — not read_file, which is parallel_safe)
    if tool_name in PATH_SCOPED_TOOLS:
        path = _extract_path(tool_name, args)
        if path:
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="path_scoped", reason=f"Path: {path}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Path-scoped but no path found"
        )

    # Default: sequential (conservative)
    return ToolCallClassification(
        tool_name=tool_name, args=args, tool_call=tool_call,
        tier="sequential", reason="Not classified as parallel-safe"
    )


def classify_tool_calls(
    tool_calls: list,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> BatchExecutionPlan:
    """Classify a batch of tool calls and produce an execution plan."""
    plan = BatchExecutionPlan()

    reserved_paths: List[Path] = []

    for tc in tool_calls:
        classification = classify_single_tool_call(tc, extra_parallel_safe)
        plan.classifications.append(classification)

        if classification.tier == "never_parallel":
            plan.sequential_batch.append(classification)
            continue

        if classification.tier == "sequential":
            plan.sequential_batch.append(classification)
            continue

        if classification.tier == "path_scoped":
            path = _extract_path(classification.tool_name, classification.args)
            if path is None:
                classification.tier = "sequential"
                classification.reason = "Path extraction failed"
                plan.sequential_batch.append(classification)
                continue

            # Check for path conflicts with already-scheduled parallel calls
            conflict = any(_paths_overlap(path, existing) for existing in reserved_paths)
            if conflict:
                classification.tier = "sequential"
                classification.reason = f"Path conflict: {path}"
                plan.sequential_batch.append(classification)
            else:
                reserved_paths.append(path)
                plan.parallel_batch.append(classification)
            continue

        if classification.tier == "parallel_safe":
            plan.parallel_batch.append(classification)
            continue

        # Fallback
        plan.sequential_batch.append(classification)

    return plan
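
# Example (illustrative): a batch of [read_file(a.py), write_file(b.py),
# write_file(b.py), terminal("ls")] yields:
#   parallel_batch   -> read_file (parallel_safe), first write_file (path_scoped)
#   sequential_batch -> second write_file (path conflict on b.py),
#                       terminal (conservative)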


# ── Concurrent Execution ───────────────────────────────────────────────────

def execute_parallel_batch(
    batch: List[ToolCallClassification],
    invoke_fn: Callable,
    max_workers: int = 8,
) -> List[Tuple[str, str]]:
    """Execute parallel-safe tool calls concurrently.

    Args:
        batch: List of classified tool calls (parallel_safe or path_scoped)
        invoke_fn: Function(tool_name, args) -> result_string
        max_workers: Max concurrent threads

    Returns:
        List of (tool_call_id, result_string) tuples
    """
    if not batch:
        # Guard: ThreadPoolExecutor raises ValueError for max_workers=0
        return []

    results: List[Tuple[str, str]] = []

    with ThreadPoolExecutor(max_workers=min(max_workers, len(batch))) as executor:
        future_to_tc = {}
        for tc in batch:
            future = executor.submit(invoke_fn, tc.tool_name, tc.args)
            future_to_tc[future] = tc

        for future in as_completed(future_to_tc):
            tc = future_to_tc[future]
            try:
                result = future.result()
            except Exception as e:
                result = json.dumps({"error": str(e)})
            tool_call_id = getattr(tc.tool_call, "id", None) or ""
            results.append((tool_call_id, result))

    return results
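
# Illustrative wiring (invoke_fn is the caller's (tool_name, args) -> result
# dispatcher; how it is built is up to the agent loop):
#
#   plan = classify_tool_calls(tool_calls)
#   results = execute_parallel_batch(plan.parallel_batch, invoke_fn)
#   for c in plan.sequential_batch:
#       results.append((getattr(c.tool_call, "id", "") or "",
#                       invoke_fn(c.tool_name, c.args)))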