"""Tests for batch tool execution — Issue #749."""

import asyncio
import sys
import time
from pathlib import Path

# Put the parent of this file's directory on sys.path so that
# ``tools.batch_executor`` can be imported when the file is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from tools.batch_executor import (  # noqa: E402  (needs the sys.path tweak above)
    ToolSafety,
    ToolCall,
    BatchResult,
    classify_tool_safety,
    classify_calls,
    execute_batch_sync,
    get_tool_safety_report,
)


class TestClassification:
    """Per-tool safety classification via ``classify_tool_safety``."""

    def test_parallel_safe_read(self):
        assert classify_tool_safety("file_read") == ToolSafety.PARALLEL_SAFE

    def test_sequential_write(self):
        assert classify_tool_safety("file_write") == ToolSafety.SEQUENTIAL

    def test_destructive_terminal(self):
        assert classify_tool_safety("terminal") == ToolSafety.DESTRUCTIVE

    def test_unknown_defaults_sequential(self):
        assert classify_tool_safety("unknown_tool") == ToolSafety.SEQUENTIAL

    def test_prefix_match(self):
        assert classify_tool_safety("file_read_special") == ToolSafety.PARALLEL_SAFE


class TestClassifyCalls:
    """Classification of raw ``{"name", "arguments"}`` call dicts."""

    def test_classifies_multiple(self):
        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            {"name": "terminal", "arguments": "{}"},
        ]
        result = classify_calls(calls)
        assert len(result) == 3
        assert result[0].safety == ToolSafety.PARALLEL_SAFE
        assert result[1].safety == ToolSafety.SEQUENTIAL
        assert result[2].safety == ToolSafety.DESTRUCTIVE


class TestBatchExecution:
    """End-to-end behaviour of ``execute_batch_sync``."""

    def test_parallel_execution(self):
        """Parallel-safe calls should overlap rather than run back to back."""
        delay = 0.1

        def slow_executor(name, args):
            time.sleep(delay)
            return f"result_{name}"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_search", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]
        # perf_counter is intended for interval timing; time.time() is a wall
        # clock that can jump if the system clock is adjusted mid-test.
        start = time.perf_counter()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.perf_counter() - start

        # Run one after another, the sleeps alone need at least
        # len(calls) * delay seconds, so finishing below that bound shows the
        # calls actually overlapped (parallel: ~delay + overhead).
        assert duration < len(calls) * delay
        assert result.parallel_count == 3
        assert len(result.errors) == 0

    def test_sequential_execution(self):
        """Sequential calls should execute one at a time."""
        delay = 0.05

        def slow_executor(name, args):
            time.sleep(delay)
            return f"result_{name}"

        calls = [
            {"name": "file_write", "arguments": "{}"},
            {"name": "file_patch", "arguments": "{}"},
        ]
        start = time.perf_counter()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.perf_counter() - start

        # Back-to-back execution cannot finish faster than the summed sleeps.
        assert duration >= len(calls) * delay
        assert result.sequential_count == 2

    def test_mixed_execution(self):
        """Mixed calls: parallel first, then sequential."""
        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        def executor(name, args):
            return f"result_{name}"

        result = execute_batch_sync(calls, executor)
        assert result.parallel_count == 2
        assert result.sequential_count == 1
        assert len(result.errors) == 0

    def test_error_handling(self):
        """Errors in one call shouldn't stop others."""
        invoked = []

        def failing_executor(name, args):
            invoked.append(name)
            if name == "file_write":
                raise Exception("Write failed")
            return "ok"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            # A second sequential call after the failing one: it must still
            # be invoked for the "shouldn't stop others" promise to hold.
            {"name": "file_patch", "arguments": "{}"},
        ]
        result = execute_batch_sync(calls, failing_executor)
        assert len(result.errors) == 1
        assert "file_write" in result.errors[0]
        assert set(invoked) == {"file_read", "file_write", "file_patch"}


class TestSafetyReport:
    """Formatting of ``get_tool_safety_report``."""

    def test_report_format(self):
        calls = [
            ToolCall(name="file_read", args={}, safety=ToolSafety.PARALLEL_SAFE, duration=0.1),
            ToolCall(name="file_write", args={}, safety=ToolSafety.SEQUENTIAL, duration=0.2),
        ]
        report = get_tool_safety_report(calls)
        assert "Parallel-safe: 1" in report
        assert "Sequential: 1" in report


if __name__ == "__main__":
    import pytest

    # Propagate pytest's exit status so failures are visible to the shell/CI.
    sys.exit(pytest.main([__file__, "-v"]))