"""
Batch tool execution with parallel safety checks (#749).

Classifies tool calls as parallel-safe or sequential, then executes them while
preserving the caller's ordering: every run of consecutive parallel-safe calls
is executed concurrently, and every sequential (state-modifying or unknown)
call runs alone, after everything before it has finished and before anything
after it starts. Results are returned in the same order as the input calls.

Safety classification:
- PARALLEL-SAFE: read_file, search_files, browser_snapshot, session_search,
  fact_store (search/probe/list/related/reason/contradict), skill_view,
  cronjob (list), memory (probe/search/list)
- SEQUENTIAL: write_file, patch, terminal, execute_code, browser_click,
  browser_type, browser_navigate, browser_back, cronjob (any other action),
  memory (any other action), skill_manage
- Unknown tools default to SEQUENTIAL.
"""

import asyncio
import itertools
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# Tools that only read state — safe to parallelize.
# NOTE: "cronjob", "memory" and "clarify" also appear in SEQUENTIAL_TOOLS.
# cronjob/memory/fact_store are resolved per sub-action in classify_tool_call
# before either set is consulted, and SEQUENTIAL_TOOLS is checked before this
# set, so "clarify" always classifies as sequential.
PARALLEL_SAFE_TOOLS = frozenset([
    "read_file",
    "search_files",
    "browser_snapshot",
    "browser_get_images",
    "browser_vision",
    "browser_console",
    "session_search",
    "fact_store",  # only some actions are read-only (see _FACT_STORE_READ_ONLY)
    "skill_view",
    "skills_list",
    "cronjob",  # only "list" is read-only (see _CRON_READ_ONLY)
    "clarify",  # shadowed: SEQUENTIAL_TOOLS is checked first
    "memory",  # only some actions are read-only (see _MEMORY_READ_ONLY)
    "vision_analyze",
])

# Tools that modify state — must be serialized
SEQUENTIAL_TOOLS = frozenset([
    "write_file",
    "patch",
    "terminal",
    "execute_code",
    "browser_click",
    "browser_type",
    "browser_press",
    "browser_scroll",
    "browser_navigate",
    "browser_back",  # history navigation — presumably changes the page like browser_navigate
    "cronjob",  # create/update/run/pause/resume/remove
    "memory",  # add/update/remove
    "skill_manage",
    "todo",
    "text_to_speech",
    "image_generate",
    "delegate_task",
    "clarify",  # clarify with choices needs user input
    "process",
])

# Sub-actions that are read-only, per tool whose safety depends on "action".
_CRON_READ_ONLY = frozenset(["list"])
_FACT_STORE_READ_ONLY = frozenset(
    ["search", "probe", "list", "related", "reason", "contradict"]
)
_MEMORY_READ_ONLY = frozenset(["probe", "search", "list"])

_ACTION_READ_ONLY: Dict[str, frozenset] = {
    "cronjob": _CRON_READ_ONLY,
    "fact_store": _FACT_STORE_READ_ONLY,
    "memory": _MEMORY_READ_ONLY,
}


@dataclass
class BatchResult:
    """Result of a batch tool execution."""

    # One entry per input call, in the same order as the input calls.
    results: List[Dict[str, Any]] = field(default_factory=list)
    # Number of calls classified as parallel-safe / sequential.
    parallel_count: int = 0
    sequential_count: int = 0
    # Wall-clock duration of the whole batch, in milliseconds.
    elapsed_ms: float = 0.0


def _get_action(tool_args: Any) -> str:
    """Return the string ``action`` argument, or "" when absent or malformed.

    Malformed input (non-dict args, non-string action) maps to "", which no
    read-only allow-list contains, so such calls fall back to "sequential".
    """
    if not isinstance(tool_args, dict):
        return ""
    action = tool_args.get("action", "")
    return action if isinstance(action, str) else ""


def classify_tool_call(tool_name: str, tool_args: Optional[Dict] = None) -> str:
    """Classify a tool call as 'parallel' or 'sequential'.

    Args:
        tool_name: Name of the tool being called.
        tool_args: The call's arguments; for cronjob/fact_store/memory the
            ``action`` key decides the classification.

    Returns:
        'parallel' if the call is known to be read-only, otherwise
        'sequential' (unknown tools and malformed arguments included).
    """
    # Special cases based on sub-action
    read_only_actions = _ACTION_READ_ONLY.get(tool_name)
    if read_only_actions is not None:
        if _get_action(tool_args) in read_only_actions:
            return "parallel"
        return "sequential"

    # Check sequential first (more restrictive)
    if tool_name in SEQUENTIAL_TOOLS:
        return "sequential"
    if tool_name in PARALLEL_SAFE_TOOLS:
        return "parallel"
    # Unknown tools default to sequential (safe)
    return "sequential"


def _classify_call(call: Dict) -> str:
    """Classify a ``{"name": ..., "args": ...}`` call dict (missing keys tolerated)."""
    return classify_tool_call(call.get("name", ""), call.get("args", {}))


def classify_batch(tool_calls: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Split a list of tool calls into parallel-safe and sequential groups.

    Args:
        tool_calls: List of dicts with 'name' and 'args' keys

    Returns:
        (parallel_calls, sequential_calls), each in input order
    """
    parallel: List[Dict] = []
    sequential: List[Dict] = []
    for call in tool_calls:
        if _classify_call(call) == "parallel":
            parallel.append(call)
        else:
            sequential.append(call)
    return parallel, sequential


async def _run_one(call: Dict, executor: Callable, parallel: bool) -> Dict[str, Any]:
    """Run one call through *executor* and wrap the outcome in a result dict.

    Any ``Exception`` (including one raised synchronously by *executor*) is
    logged and recorded in the entry's ``error`` field instead of propagating.
    """
    name = call.get("name", "")
    kind = "Parallel" if parallel else "Sequential"
    try:
        result = await executor(name, call.get("args", {}))
    except Exception as e:
        logger.exception("%s tool '%s' failed: %s", kind, name, e)
        return {"tool_name": name, "result": None, "parallel": parallel, "error": str(e)}
    return {"tool_name": name, "result": result, "parallel": parallel, "error": None}


async def execute_parallel(
    tool_calls: List[Dict],
    executor: Callable,
) -> List[Dict[str, Any]]:
    """Execute parallel-safe tool calls concurrently.

    Each call runs in its own asyncio task; a failure in one call is recorded
    in that call's result entry and does not affect the others. Cancelling the
    caller cancels all still-running tasks (``asyncio.gather`` semantics).

    Args:
        tool_calls: List of tool call dicts
        executor: Async callable(tool_name, tool_args) -> result

    Returns:
        List of results in same order as input
    """
    tasks = [
        asyncio.create_task(
            _run_one(call, executor, True),
            name=f"tool:{call.get('name', '')}",
        )
        for call in tool_calls
    ]
    return list(await asyncio.gather(*tasks))


async def execute_sequential(
    tool_calls: List[Dict],
    executor: Callable,
) -> List[Dict[str, Any]]:
    """Execute sequential tool calls one at a time, in input order.

    A failing call is recorded in its result entry and does not stop the
    remaining calls.
    """
    return [await _run_one(call, executor, False) for call in tool_calls]


async def execute_batch(
    tool_calls: List[Dict],
    executor: Callable,
) -> BatchResult:
    """Execute a batch of tool calls with parallel safety checks.

    1. Classify each call as parallel-safe or sequential
    2. Walk the batch in order, grouping consecutive calls of the same kind
    3. Execute each run of parallel-safe calls concurrently
    4. Execute each sequential call alone, after every earlier call has
       finished and before any later call starts (so e.g. a read issued after
       a write observes that write)
    5. Return results in original order

    Args:
        tool_calls: List of dicts with 'name' and 'args' keys
        executor: Async callable(tool_name, tool_args) -> result

    Returns:
        BatchResult with all results (in input order) and timing
    """
    start = time.monotonic()
    results: List[Dict[str, Any]] = []
    parallel_count = 0
    sequential_count = 0

    # groupby yields maximal runs of consecutive calls with the same
    # classification, preserving input order.
    for kind, run in itertools.groupby(tool_calls, key=_classify_call):
        calls = list(run)
        if kind == "parallel":
            parallel_count += len(calls)
            results.extend(await execute_parallel(calls, executor))
        else:
            sequential_count += len(calls)
            results.extend(await execute_sequential(calls, executor))

    elapsed = (time.monotonic() - start) * 1000
    return BatchResult(
        results=results,
        parallel_count=parallel_count,
        sequential_count=sequential_count,
        elapsed_ms=elapsed,
    )