"""Batch tool execution with parallel safety checks.

Classifies tool calls as parallel-safe vs sequential and executes
parallel-safe calls concurrently while keeping sequential and destructive
ops serialized. The relative order of the batch is preserved: a run of
consecutive parallel-safe calls executes concurrently, but never jumps ahead
of (or behind) a sequential/destructive call that sits between them.

Issue #749: feat: batch tool execution with parallel safety checks
"""

import asyncio
import functools
import itertools
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class ToolSafety(Enum):
    """Safety classification for tool calls."""

    PARALLEL_SAFE = "parallel_safe"  # Can run concurrently
    SEQUENTIAL = "sequential"  # Must run one at a time
    # Destructive calls are serialized exactly like SEQUENTIAL ones; this
    # module does not itself enforce any approval step.
    DESTRUCTIVE = "destructive"  # Destructive, needs approval


# Tool safety classifications. Keys are matched exactly first and then used
# as name prefixes by classify_tool_safety().
_TOOL_SAFETY: Dict[str, ToolSafety] = {
    # Parallel-safe: reads, searches, non-destructive
    "file_read": ToolSafety.PARALLEL_SAFE,
    "file_search": ToolSafety.PARALLEL_SAFE,
    "web_search": ToolSafety.PARALLEL_SAFE,
    "web_extract": ToolSafety.PARALLEL_SAFE,
    "browser_snapshot": ToolSafety.PARALLEL_SAFE,
    "browser_vision": ToolSafety.PARALLEL_SAFE,
    "browser_get_images": ToolSafety.PARALLEL_SAFE,
    "skill_view": ToolSafety.PARALLEL_SAFE,
    "memory_search": ToolSafety.PARALLEL_SAFE,
    "memory_recall": ToolSafety.PARALLEL_SAFE,
    "session_search": ToolSafety.PARALLEL_SAFE,
    # Sequential: writes, edits, state changes
    "file_write": ToolSafety.SEQUENTIAL,
    "file_patch": ToolSafety.SEQUENTIAL,
    "file_append": ToolSafety.SEQUENTIAL,
    "browser_navigate": ToolSafety.SEQUENTIAL,
    "browser_click": ToolSafety.SEQUENTIAL,
    "browser_type": ToolSafety.SEQUENTIAL,
    "browser_scroll": ToolSafety.SEQUENTIAL,
    "memory_store": ToolSafety.SEQUENTIAL,
    "memory_update": ToolSafety.SEQUENTIAL,
    "cronjob": ToolSafety.SEQUENTIAL,
    "send_message": ToolSafety.SEQUENTIAL,
    # Destructive: needs approval
    "terminal": ToolSafety.DESTRUCTIVE,
    "execute_code": ToolSafety.DESTRUCTIVE,
    "browser_execute_js": ToolSafety.DESTRUCTIVE,
    "delegate_task": ToolSafety.DESTRUCTIVE,
}


@dataclass
class ToolCall:
    """A single tool call with metadata."""

    name: str
    args: Dict[str, Any]
    call_id: str = ""
    safety: ToolSafety = ToolSafety.SEQUENTIAL
    result: Optional[Any] = None  # executor return value, set on success
    error: Optional[str] = None  # non-empty error description, set on failure
    duration: float = 0.0  # elapsed seconds spent executing the call
    started_at: float = 0.0  # wall-clock timestamp (time.time())
    completed_at: float = 0.0  # wall-clock timestamp (time.time())


@dataclass
class BatchResult:
    """Result of batch tool execution."""

    calls: List[ToolCall] = field(default_factory=list)  # original batch order
    parallel_count: int = 0
    sequential_count: int = 0  # SEQUENTIAL + DESTRUCTIVE calls
    total_duration: float = 0.0  # elapsed seconds for the whole batch
    errors: List[str] = field(default_factory=list)  # "name: error" per failure


def _describe_error(exc: BaseException) -> str:
    """Return a non-empty description of *exc* (falls back to its type name)."""
    # str(RuntimeError()) == "" would otherwise make a failed call look
    # successful to every `if call.error` check.
    return str(exc) or type(exc).__name__


def classify_tool_safety(tool_name: str) -> ToolSafety:
    """Classify a tool call's safety level.

    Lookup order: exact match in ``_TOOL_SAFETY``, then the first registered
    name that *tool_name* starts with, then ``ToolSafety.SEQUENTIAL`` for
    unknown tools.
    """
    exact = _TOOL_SAFETY.get(tool_name)
    if exact is not None:
        return exact
    # NOTE(review): prefix matching means an unknown tool such as
    # "file_read_and_delete" is classified PARALLEL_SAFE; confirm intended.
    for pattern, safety in _TOOL_SAFETY.items():
        if tool_name.startswith(pattern):
            return safety
    # Default to sequential for unknown tools
    return ToolSafety.SEQUENTIAL


def classify_calls(tool_calls: List[Dict[str, Any]]) -> List[ToolCall]:
    """Classify a list of tool call dicts by safety level.

    Accepts both flat dicts (``{"name", "arguments", "id"}``) and the OpenAI
    shape where name/arguments live under ``"function"``. String arguments
    are JSON-decoded; undecodable or null arguments become ``{}``. A missing
    or empty ``"id"`` is replaced by ``"call_<index>"``.
    """
    calls = []
    for i, tc in enumerate(tool_calls):
        # A null "function" entry is treated the same as a missing one.
        function = tc.get("function") or {}
        name = tc.get("name", function.get("name", ""))
        if not isinstance(name, str):
            name = "" if name is None else str(name)

        args = tc.get("arguments", function.get("arguments", {}))
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except (ValueError, RecursionError) as e:
                # json.JSONDecodeError subclasses ValueError; RecursionError
                # covers pathologically nested input. Best-effort: use {}.
                logger.warning(
                    "Could not parse arguments for tool call %s: %s", name, e
                )
                args = {}
        if args is None:
            args = {}

        calls.append(
            ToolCall(
                name=name,
                args=args,
                call_id=tc.get("id") or f"call_{i}",
                safety=classify_tool_safety(name),
            )
        )
    return calls


async def execute_parallel(
    calls: List[ToolCall],
    executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
    """Execute calls concurrently, each in the event loop's default thread pool.

    Each ToolCall is updated in place (result/error and timing fields).
    Executor exceptions are recorded on the call, not raised.

    Returns:
        The same ToolCall objects, in input order.
    """
    loop = asyncio.get_running_loop()

    async def run_call(call: ToolCall) -> ToolCall:
        call.started_at = time.time()
        t0 = time.perf_counter()
        try:
            # Run the blocking executor in a worker thread so calls overlap.
            call.result = await loop.run_in_executor(
                None, functools.partial(executor, call.name, call.args)
            )
        except Exception as e:
            call.error = _describe_error(e)
            logger.error("Parallel call %s failed: %s", call.name, e)
        finally:
            call.completed_at = time.time()
            call.duration = time.perf_counter() - t0
        return call

    # Execute all calls concurrently
    outcomes = await asyncio.gather(
        *(run_call(call) for call in calls), return_exceptions=True
    )

    processed = []
    for call, outcome in zip(calls, outcomes):
        # run_call handles Exception itself; anything that still surfaces is a
        # BaseException such as a child task's CancelledError.
        if isinstance(outcome, BaseException):
            call.error = _describe_error(outcome)
            if not call.completed_at:
                call.completed_at = time.time()
                if not call.started_at:
                    call.started_at = call.completed_at
                call.duration = call.completed_at - call.started_at
            processed.append(call)
        else:
            processed.append(outcome)
    return processed


async def execute_sequential(
    calls: List[ToolCall],
    executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
    """Execute sequential/destructive calls one at a time, in list order.

    The executor is invoked directly (not in a thread pool), so each call
    blocks the running event loop until it returns. Each ToolCall is updated
    in place; executor exceptions are recorded on the call, not raised.
    """
    for call in calls:
        call.started_at = time.time()
        t0 = time.perf_counter()
        try:
            call.result = executor(call.name, call.args)
        except Exception as e:
            call.error = _describe_error(e)
            logger.error("Sequential call %s failed: %s", call.name, e)
        finally:
            call.completed_at = time.time()
            call.duration = time.perf_counter() - t0
    return calls


async def execute_batch(
    tool_calls: List[Dict[str, Any]],
    executor: Callable[[str, Dict[str, Any]], Any],
    max_parallel: int = 5,
) -> BatchResult:
    """Execute a batch of tool calls with parallel safety checks.

    Calls are processed in their original order, grouped into contiguous
    runs: each run of consecutive parallel-safe calls executes concurrently
    (at most *max_parallel* at a time), and every other call executes alone.
    A parallel-safe call therefore never observes state from before a
    preceding write, nor races ahead of it.

    Args:
        tool_calls: List of tool call dicts (OpenAI format)
        executor: Function to execute a single tool call (name, args) -> result
        max_parallel: Maximum concurrent parallel calls (must be >= 1)

    Returns:
        BatchResult with all call results and timing info

    Raises:
        ValueError: if max_parallel is less than 1.
    """
    if max_parallel < 1:
        raise ValueError(f"max_parallel must be >= 1, got {max_parallel}")

    start = time.perf_counter()
    calls = classify_calls(tool_calls)

    parallel_count = sum(1 for c in calls if c.safety == ToolSafety.PARALLEL_SAFE)
    result = BatchResult(
        calls=calls,
        parallel_count=parallel_count,
        sequential_count=len(calls) - parallel_count,
    )

    if result.parallel_count:
        logger.info(
            "Executing %d parallel-safe calls concurrently", result.parallel_count
        )
    if result.sequential_count:
        logger.info("Executing %d sequential calls", result.sequential_count)

    # Walk contiguous runs in original order so writes and reads never swap.
    runs = itertools.groupby(
        calls, key=lambda c: c.safety == ToolSafety.PARALLEL_SAFE
    )
    for is_parallel, group in runs:
        run = list(group)
        if is_parallel:
            for i in range(0, len(run), max_parallel):
                await execute_parallel(run[i:i + max_parallel], executor)
        else:
            await execute_sequential(run, executor)

    # Collect errors
    result.errors = [f"{c.name}: {c.error}" for c in calls if c.error]
    result.total_duration = time.perf_counter() - start
    return result


def execute_batch_sync(
    tool_calls: List[Dict[str, Any]],
    executor: Callable[[str, Dict[str, Any]], Any],
    max_parallel: int = 5,
) -> BatchResult:
    """Synchronous wrapper for execute_batch.

    Uses ``asyncio.run``, so it cannot be called from a thread that already
    has a running event loop (``asyncio.run`` raises RuntimeError there).
    """
    return asyncio.run(execute_batch(tool_calls, executor, max_parallel))


def get_tool_safety_report(calls: List[ToolCall]) -> str:
    """Generate a human-readable safety report.

    Lists per-level counts, then one section per non-empty safety level with
    a ✓/✗ status (✗ when the call has an error) and its duration.
    """
    # (count label, section heading, safety level), in report order.
    sections: Tuple[Tuple[str, str, ToolSafety], ...] = (
        ("Parallel-safe", "Parallel-safe calls:", ToolSafety.PARALLEL_SAFE),
        ("Sequential", "Sequential calls:", ToolSafety.SEQUENTIAL),
        ("Destructive", "Destructive calls:", ToolSafety.DESTRUCTIVE),
    )
    grouped = [
        (label, heading, [c for c in calls if c.safety == level])
        for label, heading, level in sections
    ]

    lines = ["Tool Safety Report:"]
    lines.extend(f" {label}: {len(members)}" for label, _, members in grouped)
    for _, heading, members in grouped:
        if not members:
            continue
        lines.append(f"\n{heading}")
        for c in members:
            status = "✗" if c.error else "✓"
            lines.append(f" {status} {c.name} ({c.duration:.2f}s)")
    return "\n".join(lines)