137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
|
|
"""Tests for batch tool execution — Issue #749."""
|
||
|
|
import asyncio
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
|
|
||
|
|
from tools.batch_executor import (
|
||
|
|
ToolSafety, ToolCall, BatchResult,
|
||
|
|
classify_tool_safety, classify_calls,
|
||
|
|
execute_batch_sync, get_tool_safety_report
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestClassification:
    """classify_tool_safety maps a single tool name to a ToolSafety tier."""

    def test_parallel_safe_read(self):
        """'file_read' is classified as PARALLEL_SAFE."""
        safety = classify_tool_safety("file_read")
        assert safety == ToolSafety.PARALLEL_SAFE

    def test_sequential_write(self):
        """'file_write' is classified as SEQUENTIAL."""
        safety = classify_tool_safety("file_write")
        assert safety == ToolSafety.SEQUENTIAL

    def test_destructive_terminal(self):
        """'terminal' is classified as DESTRUCTIVE."""
        safety = classify_tool_safety("terminal")
        assert safety == ToolSafety.DESTRUCTIVE

    def test_unknown_defaults_sequential(self):
        """A name with no known classification falls back to SEQUENTIAL."""
        safety = classify_tool_safety("unknown_tool")
        assert safety == ToolSafety.SEQUENTIAL

    def test_prefix_match(self):
        """A name that extends a known tool name inherits that tool's tier."""
        safety = classify_tool_safety("file_read_special")
        assert safety == ToolSafety.PARALLEL_SAFE
|
||
|
|
|
||
|
|
|
||
|
|
class TestClassifyCalls:
    """classify_calls tags each raw call dict with its safety tier."""

    def test_classifies_multiple(self):
        """One call per tier comes back, in input order, with matching safety."""
        expected = [
            ("file_read", ToolSafety.PARALLEL_SAFE),
            ("file_write", ToolSafety.SEQUENTIAL),
            ("terminal", ToolSafety.DESTRUCTIVE),
        ]
        raw_calls = [{"name": tool, "arguments": "{}"} for tool, _ in expected]

        classified = classify_calls(raw_calls)

        assert len(classified) == 3
        for tagged, (_, wanted_safety) in zip(classified, expected):
            assert tagged.safety == wanted_safety
|
||
|
|
|
||
|
|
|
||
|
|
class TestBatchExecution:
    """End-to-end behaviour of execute_batch_sync using stub executors."""

    def test_parallel_execution(self):
        """Parallel-safe calls should execute faster than sequential."""
        import time

        delay = 0.2

        def slow_executor(name, args):
            time.sleep(delay)
            return f"result_{name}"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_search", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        # perf_counter is monotonic; time.time() can jump with clock changes.
        start = time.perf_counter()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.perf_counter() - start

        # time.sleep suspends for at least the requested time, so running the
        # calls one after another takes >= len(calls) * delay. Finishing
        # strictly under that bound proves the calls overlapped. An ideal
        # parallel run takes ~delay, which leaves ample headroom for slow CI.
        assert duration < len(calls) * delay
        assert result.parallel_count == 3
        assert len(result.errors) == 0

    def test_sequential_execution(self):
        """Sequential calls should execute one at a time."""
        import threading
        import time

        delay = 0.05
        lock = threading.Lock()
        active = 0  # calls currently inside the executor
        peak = 0  # highest value `active` ever reached

        def slow_executor(name, args):
            nonlocal active, peak
            with lock:
                active += 1
                peak = max(peak, active)
            try:
                time.sleep(delay)
            finally:
                with lock:
                    active -= 1
            return f"result_{name}"

        calls = [
            {"name": "file_write", "arguments": "{}"},
            {"name": "file_patch", "arguments": "{}"},
        ]

        start = time.perf_counter()
        result = execute_batch_sync(calls, slow_executor)
        duration = time.perf_counter() - start

        # Directly check the "one at a time" contract: no overlap ever.
        assert peak == 1
        # Back-to-back sleeps take at least len(calls) * delay in total.
        assert duration >= len(calls) * delay
        assert result.sequential_count == 2

    def test_mixed_execution(self):
        """Mixed calls: parallel first, then sequential."""
        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            {"name": "web_search", "arguments": "{}"},
        ]

        def executor(name, args):
            return f"result_{name}"

        result = execute_batch_sync(calls, executor)
        assert result.parallel_count == 2
        assert result.sequential_count == 1
        assert len(result.errors) == 0

    def test_error_handling(self):
        """Errors in one call shouldn't stop others."""
        called = []  # names the executor was invoked with

        def failing_executor(name, args):
            called.append(name)
            if name == "file_write":
                raise Exception("Write failed")
            return "ok"

        calls = [
            {"name": "file_read", "arguments": "{}"},
            {"name": "file_write", "arguments": "{}"},
            # Queued after the failing write, so the test proves that a
            # failure does not abort the remaining sequential calls.
            {"name": "file_patch", "arguments": "{}"},
        ]

        result = execute_batch_sync(calls, failing_executor)
        assert len(result.errors) == 1
        assert "file_write" in result.errors[0]
        # Every call, including the ones after the failure, was attempted.
        assert sorted(called) == sorted(call["name"] for call in calls)
|
||
|
|
|
||
|
|
|
||
|
|
class TestSafetyReport:
    """get_tool_safety_report summarises a list of ToolCall records."""

    def test_report_format(self):
        """The report lists how many calls fall into each safety tier."""
        read_call = ToolCall(
            name="file_read", args={}, safety=ToolSafety.PARALLEL_SAFE, duration=0.1
        )
        write_call = ToolCall(
            name="file_write", args={}, safety=ToolSafety.SEQUENTIAL, duration=0.2
        )

        report = get_tool_safety_report([read_call, write_call])

        for expected_line in ("Parallel-safe: 1", "Sequential: 1"):
            assert expected_line in report
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    import pytest

    # Propagate pytest's exit code so that running this file as a script
    # exits non-zero when any test fails (the bare call discarded it).
    raise SystemExit(pytest.main([__file__, "-v"]))
|