"""Batch Tool Executor — Parallel safety classification and concurrent execution.

Provides centralized classification of tool calls into parallel-safe vs
sequential, and utilities for batch execution with safety checks.

Classification tiers:
- PARALLEL_SAFE: read-only tools, no shared state (web_search, read_file, etc.)
- PATH_SCOPED: file operations that can run concurrently when paths don't overlap
- SEQUENTIAL: writes, destructive ops, terminal commands, delegation
- NEVER_PARALLEL: clarify (requires user interaction)

Usage:
    from tools.batch_executor import classify_tool_calls, BatchExecutionPlan

    plan = classify_tool_calls(tool_calls)
    if plan.can_parallelize:
        execute_concurrent(plan.parallel_batch)
    execute_sequential(plan.sequential_batch)
"""

import json
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

logger = logging.getLogger(__name__)

# ── Safety Classification ──────────────────────────────────────────────────

# Tools that can ALWAYS run in parallel (read-only, no shared state).
DEFAULT_PARALLEL_SAFE = frozenset({
    "ha_get_state",
    "ha_list_entities",
    "ha_list_services",
    "read_file",
    "search_files",
    "session_search",
    "skill_view",
    "skills_list",
    "vision_analyze",
    "web_extract",
    "web_search",
    "fact_store",
    "fact_search",
})

# File tools that can run concurrently ONLY when paths don't overlap.
# read_file is also listed in DEFAULT_PARALLEL_SAFE, which is checked first
# during classification, so only write_file/patch actually reach this tier.
PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"})

# Tools that must NEVER run in parallel (require user interaction,
# shared mutable state).
NEVER_PARALLEL = frozenset({"clarify"})

# Patterns that indicate terminal commands may modify/delete files.
# Each destructive command must appear at the start of the string or after a
# shell separator (whitespace, &&, ||, ;, backtick).
DESTRUCTIVE_PATTERNS = re.compile(
    r"""(?:^|\s|&&|\|\||;|`)(?:
        rm\s|rmdir\s|
        mv\s|
        sed\s+-i|
        truncate\s|
        dd\s|
        shred\s|
        git\s+(?:reset|clean|checkout)\s
    )""",
    re.VERBOSE,
)

# Output redirects that overwrite files: a lone ">" that is not part of ">>".
# Lookarounds (instead of [^>] guards) so a ">" at the very start or end of
# the command is still detected. Conservative: "2>&1" also matches, which
# harmlessly routes such commands to sequential execution.
REDIRECT_OVERWRITE = re.compile(r"(?<!>)>(?!>)")


def is_destructive_command(cmd: str) -> bool:
    """Check if a terminal command modifies/deletes files.

    Args:
        cmd: The shell command string ("" / None-ish values return False).

    Returns:
        True when the command matches a known destructive pattern or
        contains an overwriting output redirect.
    """
    if not cmd:
        return False
    if DESTRUCTIVE_PATTERNS.search(cmd):
        return True
    if REDIRECT_OVERWRITE.search(cmd):
        return True
    return False


def _paths_overlap(path1: Path, path2: Path) -> bool:
    """Check if two paths could conflict (equal, or one is ancestor of the other)."""
    try:
        path1 = path1.resolve()
        path2 = path2.resolve()
        return path1 == path2 or path1 in path2.parents or path2 in path1.parents
    except Exception:
        return True  # conservative: assume overlap


def _extract_path(tool_name: str, args: dict) -> Optional[Path]:
    """Extract the target path from tool arguments for path-scoped tools.

    Returns None for non-path-scoped tools, missing/blank paths, or paths
    that cannot be resolved.
    """
    if tool_name not in PATH_SCOPED_TOOLS:
        return None
    raw_path = args.get("path")
    if not isinstance(raw_path, str) or not raw_path.strip():
        return None
    try:
        return Path(raw_path).expanduser().resolve()
    except Exception:
        return None


# ── Classification ─────────────────────────────────────────────────────────

@dataclass
class ToolCallClassification:
    """Classification result for a single tool call."""

    tool_name: str
    args: dict
    tool_call: Any  # the original tool_call object
    tier: str  # "parallel_safe", "path_scoped", "sequential", "never_parallel"
    reason: str = ""


@dataclass
class BatchExecutionPlan:
    """Plan for executing a batch of tool calls."""

    classifications: List[ToolCallClassification] = field(default_factory=list)
    parallel_batch: List[ToolCallClassification] = field(default_factory=list)
    sequential_batch: List[ToolCallClassification] = field(default_factory=list)

    @property
    def can_parallelize(self) -> bool:
        # Parallelism only pays off with at least two concurrent calls.
        return len(self.parallel_batch) > 1

    @property
    def total(self) -> int:
        return len(self.classifications)


def classify_single_tool_call(
    tool_call: Any,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> ToolCallClassification:
    """Classify a single tool call into its safety tier.

    Args:
        tool_call: Object with ``.function.name`` and ``.function.arguments``
            (JSON string), e.g. an OpenAI-style tool call.
        extra_parallel_safe: Optional additional tool names to treat as
            parallel-safe on top of DEFAULT_PARALLEL_SAFE.

    Returns:
        A ToolCallClassification; unparseable or unknown tools fall back to
        the conservative "sequential" tier.
    """
    tool_name = tool_call.function.name
    try:
        args = json.loads(tool_call.function.arguments)
    except Exception:
        return ToolCallClassification(
            tool_name=tool_name, args={}, tool_call=tool_call,
            tier="sequential", reason="Could not parse arguments"
        )
    if not isinstance(args, dict):
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Non-dict arguments"
        )

    # Check never-parallel
    if tool_name in NEVER_PARALLEL:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="never_parallel", reason="Requires user interaction"
        )

    # Check parallel-safe FIRST (before path_scoped) so read_file/search_files
    # get classified as parallel_safe even though they have paths
    parallel_safe_set = DEFAULT_PARALLEL_SAFE
    if extra_parallel_safe:
        parallel_safe_set = parallel_safe_set | extra_parallel_safe
    if tool_name in parallel_safe_set:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="parallel_safe", reason="Read-only, no shared state"
        )

    # Check terminal commands for destructive operations
    if tool_name == "terminal":
        cmd = args.get("command", "")
        if is_destructive_command(cmd):
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="sequential", reason=f"Destructive command: {cmd[:50]}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Terminal command (conservative)"
        )

    # Check path-scoped tools (write_file, patch — not read_file which is parallel_safe)
    if tool_name in PATH_SCOPED_TOOLS:
        path = _extract_path(tool_name, args)
        if path:
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="path_scoped", reason=f"Path: {path}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Path-scoped but no path found"
        )

    # Default: sequential (conservative)
    return ToolCallClassification(
        tool_name=tool_name, args=args, tool_call=tool_call,
        tier="sequential", reason="Not classified as parallel-safe"
    )


def classify_tool_calls(
    tool_calls: list,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> BatchExecutionPlan:
    """Classify a batch of tool calls and produce an execution plan.

    path_scoped calls are admitted to the parallel batch only when their
    target path does not overlap a path already reserved by an earlier
    path_scoped call in the same batch; conflicting calls are demoted to
    sequential.

    NOTE(review): parallel_safe read_file calls do not reserve their paths,
    so a path_scoped write to the same file can be scheduled concurrently
    with a read of it — confirm this matches the intended safety model.
    """
    plan = BatchExecutionPlan()
    reserved_paths: List[Path] = []

    for tc in tool_calls:
        classification = classify_single_tool_call(tc, extra_parallel_safe)
        plan.classifications.append(classification)

        if classification.tier == "never_parallel":
            plan.sequential_batch.append(classification)
            continue

        if classification.tier == "sequential":
            plan.sequential_batch.append(classification)
            continue

        if classification.tier == "path_scoped":
            path = _extract_path(classification.tool_name, classification.args)
            if path is None:
                classification.tier = "sequential"
                classification.reason = "Path extraction failed"
                plan.sequential_batch.append(classification)
                continue
            # Check for path conflicts with already-scheduled parallel calls
            conflict = any(_paths_overlap(path, existing) for existing in reserved_paths)
            if conflict:
                classification.tier = "sequential"
                classification.reason = f"Path conflict: {path}"
                plan.sequential_batch.append(classification)
            else:
                reserved_paths.append(path)
                plan.parallel_batch.append(classification)
            continue

        if classification.tier == "parallel_safe":
            plan.parallel_batch.append(classification)
            continue

        # Fallback
        plan.sequential_batch.append(classification)

    return plan


# ── Concurrent Execution ───────────────────────────────────────────────────

def execute_parallel_batch(
    batch: List[ToolCallClassification],
    invoke_fn: Callable,
    max_workers: int = 8,
) -> List[Tuple[str, str]]:
    """Execute parallel-safe tool calls concurrently.

    Args:
        batch: List of classified tool calls (parallel_safe or path_scoped)
        invoke_fn: Function(tool_name, args) -> result_string
        max_workers: Max concurrent threads

    Returns:
        List of (tool_call_id, result_string) tuples, in completion order.
        Exceptions from invoke_fn become JSON {"error": ...} result strings.
    """
    # Guard: ThreadPoolExecutor(max_workers=0) raises ValueError, so an
    # empty batch must short-circuit before the executor is created.
    if not batch:
        return []

    results = []
    with ThreadPoolExecutor(max_workers=min(max_workers, len(batch))) as executor:
        future_to_tc = {}
        for tc in batch:
            future = executor.submit(invoke_fn, tc.tool_name, tc.args)
            future_to_tc[future] = tc

        for future in as_completed(future_to_tc):
            tc = future_to_tc[future]
            try:
                result = future.result()
            except Exception as e:
                result = json.dumps({"error": str(e)})
            tool_call_id = getattr(tc.tool_call, "id", None) or ""
            results.append((tool_call_id, result))

    return results