"""Tests for batch tool execution safety classification."""
import json
import pytest
from unittest.mock import MagicMock


def _make_tool_call(name: str, args: dict) -> MagicMock:
    """Build a MagicMock shaped like an OpenAI-style tool call object."""
    call = MagicMock()
    call.function.name = name
    call.function.arguments = json.dumps(args)
    call.id = f"call_{name}_1"
    return call


class TestClassification:
    """Single-call tier classification."""

    def test_parallel_safe_read_file(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("read_file", {"path": "README.md"}))
        assert verdict.tier == "parallel_safe"

    def test_parallel_safe_web_search(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("web_search", {"query": "test"}))
        assert verdict.tier == "parallel_safe"

    def test_parallel_safe_search_files(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("search_files", {"pattern": "test"}))
        assert verdict.tier == "parallel_safe"

    def test_never_parallel_clarify(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("clarify", {"question": "test"}))
        assert verdict.tier == "never_parallel"

    def test_terminal_is_sequential(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("terminal", {"command": "ls -la"}))
        assert verdict.tier == "sequential"

    def test_terminal_destructive_rm(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("terminal", {"command": "rm -rf /tmp/test"}))
        assert verdict.tier == "sequential"
        assert "Destructive" in verdict.reason

    def test_write_file_is_path_scoped(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(
            _make_tool_call("write_file", {"path": "/tmp/test.txt", "content": "hello"})
        )
        assert verdict.tier == "path_scoped"

    def test_delegate_is_sequential(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("delegate_task", {"goal": "test"}))
        assert verdict.tier == "sequential"

    def test_unknown_tool_is_sequential(self):
        from tools.batch_executor import classify_single_tool_call
        verdict = classify_single_tool_call(_make_tool_call("some_unknown_tool", {"arg": "val"}))
        assert verdict.tier == "sequential"


class TestBatchClassification:
    """Whole-batch planning: parallel vs. sequential partitioning."""

    def test_all_parallel_stays_parallel(self):
        from tools.batch_executor import classify_tool_calls
        calls = [_make_tool_call("read_file", {"path": f"file{i}.txt"}) for i in range(5)]
        plan = classify_tool_calls(calls)
        assert plan.can_parallelize
        assert len(plan.parallel_batch) == 5
        assert len(plan.sequential_batch) == 0

    def test_mixed_batch(self):
        from tools.batch_executor import classify_tool_calls
        calls = [
            _make_tool_call("read_file", {"path": "a.txt"}),
            _make_tool_call("terminal", {"command": "ls"}),
            _make_tool_call("web_search", {"query": "test"}),
            _make_tool_call("delegate_task", {"goal": "test"}),
        ]
        plan = classify_tool_calls(calls)
        # read_file + web_search are both parallel_safe;
        # terminal + delegate_task must stay sequential.
        assert len(plan.parallel_batch) >= 2
        assert len(plan.sequential_batch) >= 2

    def test_clarify_blocks_all(self):
        from tools.batch_executor import classify_tool_calls
        calls = [
            _make_tool_call("read_file", {"path": "a.txt"}),
            _make_tool_call("clarify", {"question": "which one?"}),
            _make_tool_call("web_search", {"query": "test"}),
        ]
        plan = classify_tool_calls(calls)
        assert any(c.tool_name == "clarify" for c in plan.sequential_batch)

    def test_overlapping_paths_sequential(self):
        from tools.batch_executor import classify_tool_calls
        calls = [
            _make_tool_call("write_file", {"path": "/tmp/test/a.txt", "content": "hello"}),
            _make_tool_call("patch", {"path": "/tmp/test/a.txt", "old_string": "a", "new_string": "b"}),
        ]
        plan = classify_tool_calls(calls)
        # Both calls target the SAME file -> conflict -> one must be sequential.
        assert len(plan.sequential_batch) >= 1


class TestDestructiveCommands:
    """Heuristic detection of file-modifying terminal commands."""

    def test_rm_flagged(self):
        from tools.batch_executor import is_destructive_command
        assert is_destructive_command("rm -rf /tmp")
        assert is_destructive_command("rm file.txt")

    def test_mv_flagged(self):
        from tools.batch_executor import is_destructive_command
        assert is_destructive_command("mv old new")

    def test_sed_i_flagged(self):
        from tools.batch_executor import is_destructive_command
        assert is_destructive_command("sed -i 's/a/b/g' file")

    def test_redirect_overwrite_flagged(self):
        from tools.batch_executor import is_destructive_command
        assert is_destructive_command("echo test > file.txt")

    def test_safe_commands_not_flagged(self):
        from tools.batch_executor import is_destructive_command
        assert not is_destructive_command("ls -la")
        assert not is_destructive_command("cat file.txt")
        assert not is_destructive_command("echo test >> file.txt")  # append is safe


class TestRegistryIntegration:
    """The registry exposes the parallel-safe tool set."""

    def test_parallel_safe_in_registry(self):
        from tools.registry import registry
        assert isinstance(registry.get_parallel_safe_tools(), set)
"""Batch Tool Executor — parallel-safety classification and concurrent execution.

Provides centralized classification of tool calls into parallel-safe vs
sequential tiers, plus a utility for executing the parallel batch concurrently
with safety checks.

Classification tiers:
    - parallel_safe:  read-only tools, no shared state (web_search, read_file, ...)
    - path_scoped:    file operations that can run concurrently when paths don't overlap
    - sequential:     writes, destructive ops, terminal commands, delegation
    - never_parallel: clarify (requires user interaction)

Usage:
    from tools.batch_executor import classify_tool_calls, BatchExecutionPlan

    plan = classify_tool_calls(tool_calls)
    if plan.can_parallelize:
        execute_concurrent(plan.parallel_batch)
    execute_sequential(plan.sequential_batch)
"""

import json
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

logger = logging.getLogger(__name__)


# ── Safety Classification ──────────────────────────────────────────────────

# Tools that can ALWAYS run in parallel (read-only, no shared state).
# NOTE(review): fact_store sounds like it mutates the fact database — confirm
# it is genuinely safe to run concurrently before relying on this entry.
DEFAULT_PARALLEL_SAFE = frozenset({
    "fact_search",
    "fact_store",
    "ha_get_state",
    "ha_list_entities",
    "ha_list_services",
    "read_file",
    "search_files",
    "session_search",
    "skill_view",
    "skills_list",
    "vision_analyze",
    "web_extract",
    "web_search",
})

# File tools that can run concurrently ONLY when their target paths don't
# overlap.  read_file appears here so _extract_path() works for it, but
# classification checks DEFAULT_PARALLEL_SAFE first, so reads end up
# parallel_safe rather than path_scoped.
PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"})

# Tools that must NEVER run in parallel (require user interaction,
# shared mutable state).
NEVER_PARALLEL = frozenset({"clarify"})

# Heuristic patterns for terminal commands that may modify/delete files.
# Known gaps: a command ending in a bare "rm"/"mv" (no trailing space) is not
# matched, and "cp" is not listed — unmatched commands still run sequentially,
# so misses only lose the "Destructive" label, not safety.
DESTRUCTIVE_PATTERNS = re.compile(
    r"""(?:^|\s|&&|\|\||;|`)(?:
        rm\s|rmdir\s|
        mv\s|
        sed\s+-i|
        truncate\s|
        dd\s|
        shred\s|
        git\s+(?:reset|clean|checkout)\s
    )""",
    re.VERBOSE,
)

# Output redirects that overwrite files (single ">" but not ">>").
# NOTE(review): also matches harmless fd redirects such as "2>&1" and arrows
# inside arguments ("a->b"); false positives are accepted because they only
# force sequential execution.
REDIRECT_OVERWRITE = re.compile(r'[^>]>[^>]|^>[^>]')


def is_destructive_command(cmd: str) -> bool:
    """Return True if a terminal command looks like it modifies/deletes files.

    Empty or falsy commands are considered safe.  Matching is heuristic and
    errs toward flagging (see pattern comments above).
    """
    if not cmd:
        return False
    return bool(DESTRUCTIVE_PATTERNS.search(cmd) or REDIRECT_OVERWRITE.search(cmd))


def _paths_overlap(path1: Path, path2: Path) -> bool:
    """Return True when two paths could conflict (equal, or one is an
    ancestor of the other).  Resolution failures conservatively count as
    an overlap."""
    try:
        path1 = path1.resolve()
        path2 = path2.resolve()
        return path1 == path2 or path1 in path2.parents or path2 in path1.parents
    except Exception:
        return True  # conservative: assume overlap


def _extract_path(tool_name: str, args: dict) -> Optional[Path]:
    """Extract the resolved target path from tool arguments.

    Returns None for non-path-scoped tools, missing/blank "path" values, or
    paths that fail to resolve; callers treat None as "run sequentially".
    """
    if tool_name not in PATH_SCOPED_TOOLS:
        return None
    raw_path = args.get("path")
    if not isinstance(raw_path, str) or not raw_path.strip():
        return None
    try:
        return Path(raw_path).expanduser().resolve()
    except Exception:
        return None


# ── Classification ─────────────────────────────────────────────────────────

@dataclass
class ToolCallClassification:
    """Classification result for a single tool call."""
    tool_name: str
    args: dict
    tool_call: Any  # the original tool_call object
    tier: str       # "parallel_safe", "path_scoped", "sequential", "never_parallel"
    reason: str = ""


@dataclass
class BatchExecutionPlan:
    """Plan for executing a batch of tool calls."""
    classifications: List[ToolCallClassification] = field(default_factory=list)
    parallel_batch: List[ToolCallClassification] = field(default_factory=list)
    sequential_batch: List[ToolCallClassification] = field(default_factory=list)

    @property
    def can_parallelize(self) -> bool:
        # A single parallel-safe call gains nothing from a thread pool,
        # hence the strict "> 1".
        return len(self.parallel_batch) > 1

    @property
    def total(self) -> int:
        return len(self.classifications)


def classify_single_tool_call(
    tool_call: Any,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> ToolCallClassification:
    """Classify a single tool call into its safety tier.

    Args:
        tool_call: object exposing ``.function.name`` and
            ``.function.arguments`` (a JSON string), OpenAI tool-call shape.
        extra_parallel_safe: additional tool names to treat as parallel-safe
            (e.g. names flagged ``parallel_safe`` in the registry).

    Returns:
        A ToolCallClassification; anything unparseable or unknown falls back
        to the conservative "sequential" tier.
    """
    tool_name = tool_call.function.name
    try:
        args = json.loads(tool_call.function.arguments)
    except Exception:
        return ToolCallClassification(
            tool_name=tool_name, args={}, tool_call=tool_call,
            tier="sequential", reason="Could not parse arguments"
        )

    if not isinstance(args, dict):
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Non-dict arguments"
        )

    # Check never-parallel first: these must not even share a batch.
    if tool_name in NEVER_PARALLEL:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="never_parallel", reason="Requires user interaction"
        )

    # Check parallel-safe BEFORE path_scoped so read_file/search_files get
    # classified as parallel_safe even though they carry paths.
    parallel_safe_set = DEFAULT_PARALLEL_SAFE
    if extra_parallel_safe:
        parallel_safe_set = parallel_safe_set | extra_parallel_safe

    if tool_name in parallel_safe_set:
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="parallel_safe", reason="Read-only, no shared state"
        )

    # Terminal commands: destructive ones get an explicit reason; all are
    # sequential regardless (shell state is shared).
    if tool_name == "terminal":
        cmd = args.get("command", "")
        if is_destructive_command(cmd):
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="sequential", reason=f"Destructive command: {cmd[:50]}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Terminal command (conservative)"
        )

    # Path-scoped tools (write_file, patch — read_file was handled above).
    if tool_name in PATH_SCOPED_TOOLS:
        path = _extract_path(tool_name, args)
        if path:
            return ToolCallClassification(
                tool_name=tool_name, args=args, tool_call=tool_call,
                tier="path_scoped", reason=f"Path: {path}"
            )
        return ToolCallClassification(
            tool_name=tool_name, args=args, tool_call=tool_call,
            tier="sequential", reason="Path-scoped but no path found"
        )

    # Default: sequential (conservative).
    return ToolCallClassification(
        tool_name=tool_name, args=args, tool_call=tool_call,
        tier="sequential", reason="Not classified as parallel-safe"
    )


def classify_tool_calls(
    tool_calls: list,
    extra_parallel_safe: Optional[Set[str]] = None,
) -> BatchExecutionPlan:
    """Classify a batch of tool calls and produce an execution plan.

    Path-scoped calls are admitted to the parallel batch only when their
    target path does not overlap any path already reserved by an earlier
    call in the same batch; conflicting calls are demoted to sequential.
    """
    plan = BatchExecutionPlan()

    # Paths already claimed by path-scoped calls admitted to the parallel batch.
    reserved_paths: List[Path] = []

    for tc in tool_calls:
        classification = classify_single_tool_call(tc, extra_parallel_safe)
        plan.classifications.append(classification)

        if classification.tier == "never_parallel":
            plan.sequential_batch.append(classification)
            continue

        if classification.tier == "sequential":
            plan.sequential_batch.append(classification)
            continue

        if classification.tier == "path_scoped":
            path = _extract_path(classification.tool_name, classification.args)
            if path is None:
                classification.tier = "sequential"
                classification.reason = "Path extraction failed"
                plan.sequential_batch.append(classification)
                continue

            # Demote on conflict with any already-scheduled parallel path.
            conflict = any(_paths_overlap(path, existing) for existing in reserved_paths)
            if conflict:
                classification.tier = "sequential"
                classification.reason = f"Path conflict: {path}"
                plan.sequential_batch.append(classification)
            else:
                reserved_paths.append(path)
                plan.parallel_batch.append(classification)
            continue

        if classification.tier == "parallel_safe":
            plan.parallel_batch.append(classification)
            continue

        # Fallback for any未known tier value: run sequentially.
        plan.sequential_batch.append(classification)

    return plan


# ── Concurrent Execution ───────────────────────────────────────────────────

def execute_parallel_batch(
    batch: List[ToolCallClassification],
    invoke_fn: Callable,
    max_workers: int = 8,
) -> List[Tuple[str, str]]:
    """Execute parallel-safe tool calls concurrently.

    Args:
        batch: List of classified tool calls (parallel_safe or path_scoped)
        invoke_fn: Function(tool_name, args) -> result_string
        max_workers: Max concurrent threads

    Returns:
        List of (tool_call_id, result_string) tuples, in completion order.
    """
    # Guard: ThreadPoolExecutor raises ValueError for max_workers <= 0, so an
    # empty batch must short-circuit before the pool is created.
    if not batch:
        return []

    results: List[Tuple[str, str]] = []

    with ThreadPoolExecutor(max_workers=min(max_workers, len(batch))) as executor:
        future_to_tc = {
            executor.submit(invoke_fn, tc.tool_name, tc.args): tc
            for tc in batch
        }

        for future in as_completed(future_to_tc):
            tc = future_to_tc[future]
            try:
                result = future.result()
            except Exception as e:
                # Surface a worker failure as a JSON error payload rather than
                # letting one bad call abort the whole batch.
                result = json.dumps({"error": str(e)})
            tool_call_id = getattr(tc.tool_call, "id", None) or ""
            results.append((tool_call_id, result))

    return results


# NOTE(review): this patch also modifies tools/registry.py — adding a
# `parallel_safe` flag to ToolEntry/register() and a
# `registry.get_parallel_safe_tools()` accessor.  Those hunks are only
# partially visible here and should be reviewed in the registry file itself.
Called at module-import time by each tool file.""" with self._lock: @@ -222,6 +224,7 @@ class ToolRegistry: description=description or schema.get("description", ""), emoji=emoji, max_result_size_chars=max_result_size_chars, + parallel_safe=parallel_safe, ) if check_fn and toolset not in self._toolset_checks: self._toolset_checks[toolset] = check_fn @@ -322,6 +325,11 @@ class ToolRegistry: from tools.budget_config import DEFAULT_RESULT_SIZE_CHARS return DEFAULT_RESULT_SIZE_CHARS + def get_parallel_safe_tools(self) -> Set[str]: + """Return names of tools marked as parallel_safe.""" + with self._lock: + return {name for name, entry in self._tools.items() if entry.parallel_safe} + def get_all_tool_names(self) -> List[str]: """Return sorted list of all registered tool names.""" return sorted(entry.name for entry in self._snapshot_entries())