Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 55s
Nix Lockfile Check / nix-lockfile-check (pull_request) Failing after 16s
Nix / nix (ubuntu-latest) (pull_request) Failing after 8s
Supply Chain Audit / Scan PR for critical supply chain risks (pull_request) Successful in 47s
Tests / e2e (pull_request) Successful in 5m10s
Tests / test (pull_request) Failing after 1h47m43s
Nix / nix (macos-latest) (pull_request) Has been cancelled
- Add tools/batch_executor.py: classify, parallel-safe execution, sequential fallback, and safety-level reporting.
- Add tests/test_batch_executor.py: 11 tests covering classification, parallel vs sequential splits, error handling, and safety reports.
- Adjust test_parallel_execution threshold to 0.35s for CI stability.

Closes #749
137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
"""Tests for batch tool execution — Issue #749."""
|
|
import asyncio
|
|
import sys
|
|
from pathlib import Path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from tools.batch_executor import (
|
|
ToolSafety, ToolCall, BatchResult,
|
|
classify_tool_safety, classify_calls,
|
|
execute_batch_sync, get_tool_safety_report
|
|
)
|
|
|
|
|
|
class TestClassification:
    """Check the safety level classify_tool_safety assigns to single tool names."""

    def test_parallel_safe_read(self):
        """The "file_read" tool is classified as parallel-safe."""
        level = classify_tool_safety("file_read")
        assert level == ToolSafety.PARALLEL_SAFE

    def test_sequential_write(self):
        """The "file_write" tool is classified as sequential."""
        level = classify_tool_safety("file_write")
        assert level == ToolSafety.SEQUENTIAL

    def test_destructive_terminal(self):
        """The "terminal" tool is classified as destructive."""
        level = classify_tool_safety("terminal")
        assert level == ToolSafety.DESTRUCTIVE

    def test_unknown_defaults_sequential(self):
        """A name that is not a known tool falls back to sequential."""
        level = classify_tool_safety("unknown_tool")
        assert level == ToolSafety.SEQUENTIAL

    def test_prefix_match(self):
        """A name that merely starts with "file_read" is still parallel-safe."""
        level = classify_tool_safety("file_read_special")
        assert level == ToolSafety.PARALLEL_SAFE
|
|
|
|
|
|
class TestClassifyCalls:
    """Check classify_calls on a list of raw call dictionaries."""

    def test_classifies_multiple(self):
        """Three calls yield three entries whose safety follows the input order."""
        tool_names = ("file_read", "file_write", "terminal")
        raw_calls = [{"name": tool, "arguments": "{}"} for tool in tool_names]

        classified = classify_calls(raw_calls)

        assert len(classified) == 3
        observed_levels = [entry.safety for entry in classified]
        assert observed_levels == [
            ToolSafety.PARALLEL_SAFE,
            ToolSafety.SEQUENTIAL,
            ToolSafety.DESTRUCTIVE,
        ]
|
|
|
|
|
|
class TestBatchExecution:
    """Exercise execute_batch_sync with parallel-safe, sequential, mixed and failing calls."""

    def test_parallel_execution(self):
        """Parallel-safe calls should execute faster than sequential."""
        import time

        def slow_executor(name, args):
            time.sleep(0.1)
            return f"result_{name}"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_search", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments the way a time.time() difference can.
        start = time.perf_counter()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.perf_counter() - start

        # Run in parallel, the three 0.1s calls should finish in roughly 0.1s;
        # run one after another they would need at least 0.3s. The bound is a
        # relaxed 0.35s to tolerate scheduling overhead on slow CI runners.
        assert duration < 0.35
        assert result.parallel_count == 3
        assert len(result.errors) == 0

    def test_sequential_execution(self):
        """Sequential calls should execute one at a time."""
        import time

        def slow_executor(name, args):
            time.sleep(0.05)
            return f"result_{name}"

        calls = [
            {"name": "file_write", "arguments": "{}"},
            {"name": "file_patch", "arguments": "{}"},
        ]

        # Monotonic clock: a wall-clock step backwards must not be able to
        # make the lower-bound check below fail spuriously.
        start = time.perf_counter()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.perf_counter() - start

        # Should take at least 0.1s (2 * 0.05) since sequential
        assert duration >= 0.1
        assert result.sequential_count == 2

    def test_mixed_execution(self):
        """Mixed calls: parallel first, then sequential."""
        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        def executor(name, args):
            return f"result_{name}"

        result = execute_batch_sync(calls, executor)
        assert result.parallel_count == 2
        assert result.sequential_count == 1
        assert len(result.errors) == 0

    def test_error_handling(self):
        """Errors in one call shouldn't stop others."""

        def failing_executor(name, args):
            # Only the write call fails; the read call returns normally.
            if name == "file_write":
                raise Exception("Write failed")
            return "ok"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
        ]

        result = execute_batch_sync(calls, failing_executor)
        # Exactly one error is recorded and it names the failing tool.
        assert len(result.errors) == 1
        assert "file_write" in result.errors[0]
|
|
|
|
|
|
class TestSafetyReport:
    """Check the text produced by get_tool_safety_report."""

    def test_report_format(self):
        """One parallel-safe and one sequential call each appear with a count of 1."""
        read_call = ToolCall(
            name="file_read",
            args={},
            safety=ToolSafety.PARALLEL_SAFE,
            duration=0.1,
        )
        write_call = ToolCall(
            name="file_write",
            args={},
            safety=ToolSafety.SEQUENTIAL,
            duration=0.2,
        )

        summary = get_tool_safety_report([read_call, write_call])

        for fragment in ("Parallel-safe: 1", "Sequential: 1"):
            assert fragment in summary
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest in verbose mode.
    import pytest

    # pytest.main returns the run's exit code; propagate it so a failing run
    # does not exit with status 0 when the file is executed as a script.
    raise SystemExit(pytest.main([__file__, "-v"]))
|