fix: batch tool execution with parallel safety checks (closes #749)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 33s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 9m12s
Tests / test (pull_request) Failing after 48m44s

This commit is contained in:
Timmy Time
2026-04-15 21:53:45 -04:00
parent db72e908f7
commit f7f89e15ff
2 changed files with 416 additions and 0 deletions

View File

@@ -0,0 +1,136 @@
"""Tests for batch tool execution — Issue #749."""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from tools.batch_executor import (
ToolSafety, ToolCall, BatchResult,
classify_tool_safety, classify_calls,
execute_batch_sync, get_tool_safety_report
)
class TestClassification:
def test_parallel_safe_read(self):
assert classify_tool_safety("file_read") == ToolSafety.PARALLEL_SAFE
def test_sequential_write(self):
assert classify_tool_safety("file_write") == ToolSafety.SEQUENTIAL
def test_destructive_terminal(self):
assert classify_tool_safety("terminal") == ToolSafety.DESTRUCTIVE
def test_unknown_defaults_sequential(self):
assert classify_tool_safety("unknown_tool") == ToolSafety.SEQUENTIAL
def test_prefix_match(self):
assert classify_tool_safety("file_read_special") == ToolSafety.PARALLEL_SAFE
class TestClassifyCalls:
def test_classifies_multiple(self):
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_write", "arguments": "{}"},
{"name": "terminal", "arguments": "{}"},
]
result = classify_calls(calls)
assert len(result) == 3
assert result[0].safety == ToolSafety.PARALLEL_SAFE
assert result[1].safety == ToolSafety.SEQUENTIAL
assert result[2].safety == ToolSafety.DESTRUCTIVE
class TestBatchExecution:
def test_parallel_execution(self):
"""Parallel-safe calls should execute faster than sequential."""
import time
def slow_executor(name, args):
time.sleep(0.1)
return f"result_{name}"
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_search", "arguments": "{}"},
{"name": "web_search", "arguments": "{}"},
]
start = time.time()
result = execute_batch_sync(calls, slow_executor)
duration = time.time() - start
# Should be faster than 0.3s (3 * 0.1) since parallel
assert duration < 0.25
assert result.parallel_count == 3
assert len(result.errors) == 0
def test_sequential_execution(self):
"""Sequential calls should execute one at a time."""
import time
def slow_executor(name, args):
time.sleep(0.05)
return f"result_{name}"
calls = [
{"name": "file_write", "arguments": "{}"},
{"name": "file_patch", "arguments": "{}"},
]
start = time.time()
result = execute_batch_sync(calls, slow_executor)
duration = time.time() - start
# Should take at least 0.1s (2 * 0.05) since sequential
assert duration >= 0.1
assert result.sequential_count == 2
def test_mixed_execution(self):
"""Mixed calls: parallel first, then sequential."""
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_write", "arguments": "{}"},
{"name": "web_search", "arguments": "{}"},
]
def executor(name, args):
return f"result_{name}"
result = execute_batch_sync(calls, executor)
assert result.parallel_count == 2
assert result.sequential_count == 1
assert len(result.errors) == 0
def test_error_handling(self):
"""Errors in one call shouldn't stop others."""
def failing_executor(name, args):
if name == "file_write":
raise Exception("Write failed")
return "ok"
calls = [
{"name": "file_read", "arguments": "{}"},
{"name": "file_write", "arguments": "{}"},
]
result = execute_batch_sync(calls, failing_executor)
assert len(result.errors) == 1
assert "file_write" in result.errors[0]
class TestSafetyReport:
def test_report_format(self):
calls = [
ToolCall(name="file_read", args={}, safety=ToolSafety.PARALLEL_SAFE, duration=0.1),
ToolCall(name="file_write", args={}, safety=ToolSafety.SEQUENTIAL, duration=0.2),
]
report = get_tool_safety_report(calls)
assert "Parallel-safe: 1" in report
assert "Sequential: 1" in report
if __name__ == "__main__":
import pytest
pytest.main([__file__, "-v"])