Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 51s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 55s
Tests / e2e (pull_request) Successful in 4m34s
Tests / test (pull_request) Failing after 56m41s
78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
"""Tests for batch tool execution (#749)."""
|
|
|
|
import pytest
|
|
from tools.batch_executor import (
|
|
classify_tool_call,
|
|
classify_batch,
|
|
)
|
|
|
|
|
|
class TestClassifyToolCall:
    """Single-call classification by ``classify_tool_call``.

    Covers plain tools, tools whose mode depends on an ``action`` argument
    (``cronjob``, ``fact_store``), and the fallback for an unknown tool name.
    """

    @staticmethod
    def _mode(tool_name, *extra):
        """Call ``classify_tool_call`` with exactly the positional args given.

        Forwarding ``*extra`` (rather than passing an explicit ``None``)
        keeps the one-argument calls identical to calling the function
        with the tool name alone.
        """
        return classify_tool_call(tool_name, *extra)

    # --- tools classified without an args dict -------------------------

    def test_read_file_is_parallel(self):
        assert self._mode("read_file") == "parallel"

    def test_search_files_is_parallel(self):
        assert self._mode("search_files") == "parallel"

    def test_write_file_is_sequential(self):
        assert self._mode("write_file") == "sequential"

    def test_terminal_is_sequential(self):
        assert self._mode("terminal") == "sequential"

    def test_execute_code_is_sequential(self):
        assert self._mode("execute_code") == "sequential"

    # --- tools whose classification depends on the "action" arg --------

    def test_cronjob_list_is_parallel(self):
        list_args = {"action": "list"}
        assert self._mode("cronjob", list_args) == "parallel"

    def test_cronjob_create_is_sequential(self):
        create_args = {"action": "create"}
        assert self._mode("cronjob", create_args) == "sequential"

    def test_fact_store_search_is_parallel(self):
        search_args = {"action": "search"}
        assert self._mode("fact_store", search_args) == "parallel"

    def test_fact_store_add_is_sequential(self):
        add_args = {"action": "add"}
        assert self._mode("fact_store", add_args) == "sequential"

    # --- fallback ------------------------------------------------------

    def test_unknown_tool_is_sequential(self):
        assert self._mode("unknown_tool") == "sequential"
|
|
|
|
|
|
class TestClassifyBatch:
    """Splitting a list of tool calls with ``classify_batch``.

    ``classify_batch`` is called with a list of ``{"name": ..., "args": ...}``
    dicts and unpacked into a ``(parallel, sequential)`` pair of lists.
    """

    def test_splits_correctly(self):
        """A mixed batch lands each call in the bucket matching its tool."""
        calls = [
            {"name": "read_file", "args": {"path": "a"}},
            {"name": "write_file", "args": {"path": "b"}},
            {"name": "search_files", "args": {"pattern": "c"}},
            {"name": "terminal", "args": {"command": "d"}},
        ]
        parallel, sequential = classify_batch(calls)
        assert len(parallel) == 2
        assert len(sequential) == 2
        assert parallel[0]["name"] == "read_file"
        assert sequential[0]["name"] == "write_file"
        # Length plus first-element checks would still pass if search_files
        # and terminal were swapped between buckets, so also pin the full
        # membership of each side (order-independent).
        assert {call["name"] for call in parallel} == {"read_file", "search_files"}
        assert {call["name"] for call in sequential} == {"write_file", "terminal"}

    def test_all_parallel(self):
        """A batch of only parallel-safe tools leaves the sequential side empty."""
        calls = [
            {"name": "read_file", "args": {}},
            {"name": "search_files", "args": {}},
        ]
        parallel, sequential = classify_batch(calls)
        assert len(parallel) == 2
        assert len(sequential) == 0

    def test_all_sequential(self):
        """A batch of only sequential tools leaves the parallel side empty."""
        calls = [
            {"name": "write_file", "args": {}},
            {"name": "terminal", "args": {}},
        ]
        parallel, sequential = classify_batch(calls)
        assert len(parallel) == 0
        assert len(sequential) == 2

    def test_empty(self):
        """An empty batch yields two empty buckets."""
        parallel, sequential = classify_batch([])
        assert len(parallel) == 0
        assert len(sequential) == 0
|