"""Tests for batch tool execution — Issue #749."""
import sys
import time
from pathlib import Path

# Make the repository root importable when this file is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from tools.batch_executor import (
    ToolSafety, ToolCall,
    classify_tool_safety, classify_calls,
    execute_batch_sync, get_tool_safety_report,
)


class TestClassification:
    """classify_tool_safety(): known names, prefix fallback, default."""

    def test_parallel_safe_read(self):
        assert classify_tool_safety("file_read") == ToolSafety.PARALLEL_SAFE

    def test_sequential_write(self):
        assert classify_tool_safety("file_write") == ToolSafety.SEQUENTIAL

    def test_destructive_terminal(self):
        assert classify_tool_safety("terminal") == ToolSafety.DESTRUCTIVE

    def test_unknown_defaults_sequential(self):
        assert classify_tool_safety("unknown_tool") == ToolSafety.SEQUENTIAL

    def test_prefix_match(self):
        # A name extending a known prefix inherits that prefix's rating.
        assert classify_tool_safety("file_read_special") == ToolSafety.PARALLEL_SAFE


class TestClassifyCalls:
    """classify_calls(): batch classification of tool-call dicts."""

    def test_classifies_multiple(self):
        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            {"name": "terminal", "arguments": "{}"},
        ]
        result = classify_calls(calls)
        assert len(result) == 3
        assert result[0].safety == ToolSafety.PARALLEL_SAFE
        assert result[1].safety == ToolSafety.SEQUENTIAL
        assert result[2].safety == ToolSafety.DESTRUCTIVE


class TestBatchExecution:
    """execute_batch_sync(): parallel vs sequential scheduling and errors.

    NOTE(review): the duration assertions below are load-sensitive; thresholds
    are kept generous but may still flake on a heavily loaded CI machine.
    """

    def test_parallel_execution(self):
        """Parallel-safe calls should execute faster than sequential."""
        def slow_executor(name, args):
            time.sleep(0.1)
            return f"result_{name}"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_search", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        start = time.time()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.time() - start

        # Should be faster than 0.3s (3 * 0.1) since parallel
        assert duration < 0.35  # parallel: ~0.1s + overhead < 0.35s
        assert result.parallel_count == 3
        assert len(result.errors) == 0

    def test_sequential_execution(self):
        """Sequential calls should execute one at a time."""
        def slow_executor(name, args):
            time.sleep(0.05)
            return f"result_{name}"

        calls = [
            {"name": "file_write", "arguments": "{}"},
            {"name": "file_patch", "arguments": "{}"},
        ]

        start = time.time()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.time() - start

        # Should take at least 0.1s (2 * 0.05) since sequential
        assert duration >= 0.1
        assert result.sequential_count == 2

    def test_mixed_execution(self):
        """Mixed calls: parallel first, then sequential."""
        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        def executor(name, args):
            return f"result_{name}"

        result = execute_batch_sync(calls, executor)
        assert result.parallel_count == 2
        assert result.sequential_count == 1
        assert len(result.errors) == 0

    def test_error_handling(self):
        """Errors in one call shouldn't stop others."""
        def failing_executor(name, args):
            if name == "file_write":
                raise Exception("Write failed")
            return "ok"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
        ]

        result = execute_batch_sync(calls, failing_executor)
        assert len(result.errors) == 1
        assert "file_write" in result.errors[0]


class TestSafetyReport:
    """get_tool_safety_report(): summary-line formatting."""

    def test_report_format(self):
        calls = [
            ToolCall(name="file_read", args={}, safety=ToolSafety.PARALLEL_SAFE, duration=0.1),
            ToolCall(name="file_write", args={}, safety=ToolSafety.SEQUENTIAL, duration=0.2),
        ]
        report = get_tool_safety_report(calls)
        assert "Parallel-safe: 1" in report
        assert "Sequential: 1" in report


if __name__ == "__main__":
    import pytest
    pytest.main([__file__, "-v"])
"""Batch tool execution with parallel safety checks.

Classifies tool calls as parallel-safe vs sequential and executes
parallel-safe calls concurrently while keeping destructive ops serialized.

Issue #749: feat: batch tool execution with parallel safety checks
"""

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class ToolSafety(Enum):
    """Safety classification for tool calls."""
    PARALLEL_SAFE = "parallel_safe"  # Can run concurrently
    SEQUENTIAL = "sequential"        # Must run one at a time
    DESTRUCTIVE = "destructive"      # Destructive, needs approval


# Tool safety classifications. Keys are exact tool names and double as
# prefixes for namespaced variants (see classify_tool_safety).
_TOOL_SAFETY: Dict[str, ToolSafety] = {
    # Parallel-safe: reads, searches, non-destructive
    "file_read": ToolSafety.PARALLEL_SAFE,
    "file_search": ToolSafety.PARALLEL_SAFE,
    "web_search": ToolSafety.PARALLEL_SAFE,
    "web_extract": ToolSafety.PARALLEL_SAFE,
    "browser_snapshot": ToolSafety.PARALLEL_SAFE,
    "browser_vision": ToolSafety.PARALLEL_SAFE,
    "browser_get_images": ToolSafety.PARALLEL_SAFE,
    "skill_view": ToolSafety.PARALLEL_SAFE,
    "memory_search": ToolSafety.PARALLEL_SAFE,
    "memory_recall": ToolSafety.PARALLEL_SAFE,
    "session_search": ToolSafety.PARALLEL_SAFE,

    # Sequential: writes, edits, state changes
    "file_write": ToolSafety.SEQUENTIAL,
    "file_patch": ToolSafety.SEQUENTIAL,
    "file_append": ToolSafety.SEQUENTIAL,
    "browser_navigate": ToolSafety.SEQUENTIAL,
    "browser_click": ToolSafety.SEQUENTIAL,
    "browser_type": ToolSafety.SEQUENTIAL,
    "browser_scroll": ToolSafety.SEQUENTIAL,
    "memory_store": ToolSafety.SEQUENTIAL,
    "memory_update": ToolSafety.SEQUENTIAL,
    "cronjob": ToolSafety.SEQUENTIAL,
    "send_message": ToolSafety.SEQUENTIAL,

    # Destructive: needs approval
    "terminal": ToolSafety.DESTRUCTIVE,
    "execute_code": ToolSafety.DESTRUCTIVE,
    "browser_execute_js": ToolSafety.DESTRUCTIVE,
    "delegate_task": ToolSafety.DESTRUCTIVE,
}


@dataclass
class ToolCall:
    """A single tool call with metadata."""
    name: str                               # tool name, e.g. "file_read"
    args: Dict[str, Any]                    # parsed call arguments
    call_id: str = ""                       # caller-supplied id, or "call_<i>"
    safety: ToolSafety = ToolSafety.SEQUENTIAL
    result: Optional[Any] = None            # executor return value on success
    error: Optional[str] = None             # stringified exception on failure
    duration: float = 0.0                   # completed_at - started_at
    started_at: float = 0.0                 # time.time() at call start
    completed_at: float = 0.0               # time.time() at call end


@dataclass
class BatchResult:
    """Result of batch tool execution."""
    calls: List[ToolCall] = field(default_factory=list)
    parallel_count: int = 0                 # calls run on the concurrent path
    sequential_count: int = 0               # calls run one at a time
    total_duration: float = 0.0             # wall-clock time for the batch
    errors: List[str] = field(default_factory=list)  # "<name>: <error>" entries


def classify_tool_safety(tool_name: str) -> ToolSafety:
    """Classify a tool call's safety level.

    Exact matches win. Otherwise the longest known prefix of the name
    decides (so "file_read_special" inherits "file_read"'s rating) —
    longest-prefix-wins keeps the result independent of table ordering.
    Unknown tools default to SEQUENTIAL as the conservative choice.
    """
    # Check exact match first
    if tool_name in _TOOL_SAFETY:
        return _TOOL_SAFETY[tool_name]

    # Check prefix matches; prefer the most specific (longest) prefix so the
    # answer does not depend on dict insertion order.
    best: Optional[ToolSafety] = None
    best_len = -1
    for pattern, safety in _TOOL_SAFETY.items():
        if tool_name.startswith(pattern) and len(pattern) > best_len:
            best = safety
            best_len = len(pattern)
    if best is not None:
        return best

    # Default to sequential for unknown tools
    return ToolSafety.SEQUENTIAL


def classify_calls(tool_calls: List[Dict[str, Any]]) -> List[ToolCall]:
    """Classify a list of tool calls by safety level.

    Accepts both flat dicts ({"name", "arguments", "id"}) and OpenAI-nested
    dicts ({"function": {"name", "arguments"}}). String arguments are
    JSON-decoded; malformed JSON falls back to {} (best-effort, no raise).
    """
    calls: List[ToolCall] = []
    for i, tc in enumerate(tool_calls):
        name = tc.get("name", tc.get("function", {}).get("name", ""))
        args = tc.get("arguments", tc.get("function", {}).get("arguments", {}))
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except Exception:
                args = {}  # malformed arguments: classify anyway with empty args

        calls.append(ToolCall(
            name=name,
            args=args,
            call_id=tc.get("id", f"call_{i}"),
            safety=classify_tool_safety(name),
        ))

    return calls
async def execute_parallel(
    calls: List[ToolCall],
    executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
    """Execute parallel-safe calls concurrently.

    Each call runs in the default thread pool so a blocking executor cannot
    stall the event loop. Failures are captured per-call in ToolCall.error;
    start/end timestamps and duration are recorded on every call.
    """
    async def run_call(call: ToolCall) -> ToolCall:
        call.started_at = time.time()
        try:
            # Run in thread pool to avoid blocking.
            # get_running_loop() is the supported call inside a coroutine;
            # get_event_loop() is deprecated here (Python 3.10+).
            loop = asyncio.get_running_loop()
            call.result = await loop.run_in_executor(
                None,
                lambda: executor(call.name, call.args),
            )
        except Exception as e:
            call.error = str(e)
            logger.error(f"Parallel call {call.name} failed: {e}")
        finally:
            call.completed_at = time.time()
            call.duration = call.completed_at - call.started_at
        return call

    # Execute all parallel-safe calls concurrently
    tasks = [run_call(call) for call in calls]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # run_call already traps executor exceptions, but guard against anything
    # unexpected surfacing through gather() (e.g. CancelledError subclasses).
    processed: List[ToolCall] = []
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            calls[i].error = str(result)
            calls[i].completed_at = time.time()
            calls[i].duration = calls[i].completed_at - calls[i].started_at
            processed.append(calls[i])
        else:
            processed.append(result)

    return processed


async def execute_sequential(
    calls: List[ToolCall],
    executor: Callable[[str, Dict[str, Any]], Any],
) -> List[ToolCall]:
    """Execute sequential/destructive calls one at a time, in order.

    NOTE(review): the executor runs inline, so it blocks the event loop for
    the duration of each call; harmless when execute_batch is the only task
    on the loop, but consider run_in_executor if the loop is shared.
    """
    for call in calls:
        call.started_at = time.time()
        try:
            call.result = executor(call.name, call.args)
        except Exception as e:
            call.error = str(e)
            logger.error(f"Sequential call {call.name} failed: {e}")
        finally:
            call.completed_at = time.time()
            call.duration = call.completed_at - call.started_at

    return calls


async def execute_batch(
    tool_calls: List[Dict[str, Any]],
    executor: Callable[[str, Dict[str, Any]], Any],
    max_parallel: int = 5,
) -> BatchResult:
    """Execute a batch of tool calls with parallel safety checks.

    Parallel-safe calls run concurrently (in chunks of *max_parallel*),
    then sequential/destructive calls run one at a time.

    Args:
        tool_calls: List of tool call dicts (OpenAI format)
        executor: Function to execute a single tool call (name, args) -> result
        max_parallel: Maximum concurrent parallel calls

    Returns:
        BatchResult with all call results and timing info
    """
    start_time = time.time()

    # Classify all calls
    calls = classify_calls(tool_calls)

    # Split by safety level: anything not parallel-safe is serialized.
    parallel_calls = [c for c in calls if c.safety == ToolSafety.PARALLEL_SAFE]
    sequential_calls = [c for c in calls if c.safety != ToolSafety.PARALLEL_SAFE]

    result = BatchResult(
        calls=calls,
        parallel_count=len(parallel_calls),
        sequential_count=len(sequential_calls),
    )

    # Execute parallel calls concurrently, bounded by max_parallel per chunk.
    if parallel_calls:
        logger.info(f"Executing {len(parallel_calls)} parallel-safe calls concurrently")
        for i in range(0, len(parallel_calls), max_parallel):
            chunk = parallel_calls[i:i + max_parallel]
            await execute_parallel(chunk, executor)

    # Execute sequential calls one at a time
    if sequential_calls:
        logger.info(f"Executing {len(sequential_calls)} sequential calls")
        await execute_sequential(sequential_calls, executor)

    # Collect errors from every call, regardless of which path ran it.
    for call in calls:
        if call.error:
            result.errors.append(f"{call.name}: {call.error}")

    result.total_duration = time.time() - start_time

    return result


def execute_batch_sync(
    tool_calls: List[Dict[str, Any]],
    executor: Callable[[str, Dict[str, Any]], Any],
    max_parallel: int = 5,
) -> BatchResult:
    """Synchronous wrapper for execute_batch.

    Uses asyncio.run, so it must NOT be called from inside a running
    event loop (asyncio.run raises RuntimeError in that case).
    """
    return asyncio.run(execute_batch(tool_calls, executor, max_parallel))


def get_tool_safety_report(calls: List[ToolCall]) -> str:
    """Generate a human-readable safety report.

    A summary count per safety class, followed by one per-call line
    (✓ success / ✗ error, name, duration) for each non-empty class.
    """
    parallel = [c for c in calls if c.safety == ToolSafety.PARALLEL_SAFE]
    sequential = [c for c in calls if c.safety == ToolSafety.SEQUENTIAL]
    destructive = [c for c in calls if c.safety == ToolSafety.DESTRUCTIVE]

    lines = ["Tool Safety Report:"]
    lines.append(f"  Parallel-safe: {len(parallel)}")
    lines.append(f"  Sequential: {len(sequential)}")
    lines.append(f"  Destructive: {len(destructive)}")

    # One loop over the three groups instead of three copy-pasted sections.
    for title, group in (
        ("Parallel-safe calls", parallel),
        ("Sequential calls", sequential),
        ("Destructive calls", destructive),
    ):
        if group:
            lines.append(f"\n{title}:")
            for c in group:
                status = "✓" if not c.error else "✗"
                lines.append(f"  {status} {c.name} ({c.duration:.2f}s)")

    return "\n".join(lines)