264 lines
9.4 KiB
Python
264 lines
9.4 KiB
Python
|
|
"""Unit tests for AIAgent pre/post-LLM-call guardrails.
|
||
|
|
|
||
|
|
Covers three guardrail static methods on AIAgent (inspired by PR #1321 — @alireza78a):
- _sanitize_api_messages() — Phase 1: orphaned tool pair repair
- _cap_delegate_task_calls() — Phase 2a: subagent concurrency limit
- _deduplicate_tool_calls() — Phase 2b: identical call deduplication

Also covers the _get_tool_call_id_static() helper.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import types
|
||
|
|
|
||
|
|
from run_agent import AIAgent
|
||
|
|
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Helpers
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
    """Build a bare tool_call stand-in shaped like the OpenAI SDK object.

    The returned namespace carries only a ``function`` attribute, which in
    turn exposes ``name`` and ``arguments``.
    """
    function_part = types.SimpleNamespace(name=name, arguments=arguments)
    return types.SimpleNamespace(function=function_part)
|
||
|
|
|
||
|
|
|
||
|
|
def tool_result(call_id: str, content: str = "ok") -> dict:
    """Return a ``role == "tool"`` message dict answering ``call_id``."""
    return dict(role="tool", tool_call_id=call_id, content=content)
|
||
|
|
|
||
|
|
|
||
|
|
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
    """Return a dict-style tool_call, the form kept in message history."""
    function_payload = {"name": name, "arguments": "{}"}
    return {"id": call_id, "function": function_payload}
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Phase 1 — _sanitize_api_messages
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestSanitizeApiMessages:
    """Phase 1 checks for ``AIAgent._sanitize_api_messages``."""

    def test_orphaned_result_removed(self):
        """A tool result whose id matches no assistant tool_call is dropped."""
        assistant_turn = {"role": "assistant", "tool_calls": [assistant_dict_call("c1")]}
        history = [assistant_turn, tool_result("c1"), tool_result("c_ORPHAN")]

        sanitized = AIAgent._sanitize_api_messages(history)

        assert len(sanitized) == 2
        assert not any(m.get("tool_call_id") == "c_ORPHAN" for m in sanitized)

    def test_orphaned_call_gets_stub_result(self):
        """An unanswered tool_call is followed by a non-empty tool message."""
        history = [{"role": "assistant", "tool_calls": [assistant_dict_call("c2")]}]

        sanitized = AIAgent._sanitize_api_messages(history)

        assert len(sanitized) == 2
        stub = sanitized[1]
        assert stub["role"] == "tool"
        assert stub["tool_call_id"] == "c2"
        assert stub["content"]

    def test_clean_messages_pass_through(self):
        """A history with correctly paired calls and results compares equal."""
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
            tool_result("c3"),
            {"role": "assistant", "content": "done"},
        ]

        sanitized = AIAgent._sanitize_api_messages(history)

        assert sanitized == history

    def test_mixed_orphaned_result_and_orphaned_call(self):
        """Dangling results vanish while every declared call has a result."""
        two_calls = [assistant_dict_call("c4"), assistant_dict_call("c5")]
        history = [
            {"role": "assistant", "tool_calls": two_calls},
            tool_result("c4"),
            tool_result("c_DANGLING"),
        ]

        sanitized = AIAgent._sanitize_api_messages(history)

        tool_ids = [
            m.get("tool_call_id") for m in sanitized if m.get("role") == "tool"
        ]
        assert "c_DANGLING" not in tool_ids
        assert "c4" in tool_ids
        assert "c5" in tool_ids

    def test_empty_list_is_safe(self):
        """An empty history yields an empty list."""
        assert AIAgent._sanitize_api_messages([]) == []

    def test_no_tool_messages(self):
        """A history without any tool traffic compares equal afterwards."""
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ]

        sanitized = AIAgent._sanitize_api_messages(history)

        assert sanitized == history

    def test_sdk_object_tool_calls(self):
        """Attribute-style (SDK object) tool_calls are handled like dicts."""
        function_part = types.SimpleNamespace(name="terminal", arguments="{}")
        sdk_call = types.SimpleNamespace(id="c6", function=function_part)
        history = [{"role": "assistant", "tool_calls": [sdk_call]}]

        sanitized = AIAgent._sanitize_api_messages(history)

        assert len(sanitized) == 2
        assert sanitized[1]["tool_call_id"] == "c6"
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Phase 2a — _cap_delegate_task_calls
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestCapDelegateTaskCalls:
    """Phase 2a checks for ``AIAgent._cap_delegate_task_calls``."""

    def test_excess_delegates_truncated(self):
        """Two delegate calls over the limit are trimmed down to the limit."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]

        capped = AIAgent._cap_delegate_task_calls(calls)

        kept_delegates = [tc for tc in capped if tc.function.name == "delegate_task"]
        assert len(kept_delegates) == MAX_CONCURRENT_CHILDREN

    def test_non_delegate_calls_preserved(self):
        """Non-delegate calls survive even when delegates are trimmed."""
        delegates = [
            make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)
        ]
        others = [make_tc("terminal"), make_tc("web_search")]

        capped = AIAgent._cap_delegate_task_calls(delegates + others)

        surviving_names = [tc.function.name for tc in capped]
        assert "terminal" in surviving_names
        assert "web_search" in surviving_names

    def test_at_limit_passes_through(self):
        """Exactly the limit of delegates returns the very same list object."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]

        capped = AIAgent._cap_delegate_task_calls(calls)

        assert capped is calls

    def test_below_limit_passes_through(self):
        """Fewer delegates than the limit returns the very same list object."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]

        capped = AIAgent._cap_delegate_task_calls(calls)

        assert capped is calls

    def test_no_delegate_calls_unchanged(self):
        """A list without delegate calls returns the very same list object."""
        calls = [make_tc("terminal"), make_tc("web_search")]

        capped = AIAgent._cap_delegate_task_calls(calls)

        assert capped is calls

    def test_empty_list_safe(self):
        """An empty list yields an empty list."""
        assert AIAgent._cap_delegate_task_calls([]) == []

    def test_original_list_not_mutated(self):
        """The input list keeps its length after capping."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        length_before = len(calls)

        AIAgent._cap_delegate_task_calls(calls)

        assert len(calls) == length_before

    def test_interleaved_order_preserved(self):
        """Kept calls stay in their original relative order and identity."""
        delegates = [
            make_tc("delegate_task", f'{{"task":"{i}"}}')
            for i in range(MAX_CONCURRENT_CHILDREN + 1)
        ]
        terminal_call = make_tc("terminal", '{"cmd":"ls"}')
        search_call = make_tc("web_search", '{"q":"x"}')
        head = [delegates[0], terminal_call, delegates[1], search_call]
        calls = head + delegates[2:]

        capped = AIAgent._cap_delegate_task_calls(calls)

        # Only the delegates beyond the limit should be missing from the tail.
        expected = head + delegates[2:MAX_CONCURRENT_CHILDREN]
        assert len(capped) == len(expected)
        for i, (got, want) in enumerate(zip(capped, expected)):
            assert got is want, f"mismatch at index {i}"
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Phase 2b — _deduplicate_tool_calls
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestDeduplicateToolCalls:
    """Phase 2b checks for ``AIAgent._deduplicate_tool_calls``."""

    def test_duplicate_pair_deduplicated(self):
        """Two calls with identical name and arguments collapse to one."""
        calls = [make_tc("web_search", '{"query":"foo"}') for _ in range(2)]

        deduped = AIAgent._deduplicate_tool_calls(calls)

        assert len(deduped) == 1

    def test_multiple_duplicates(self):
        """Each distinct (name, arguments) pair is kept exactly once."""
        search_a = '{"q":"a"}'
        list_cmd = '{"cmd":"ls"}'
        calls = [
            make_tc("web_search", search_a),
            make_tc("web_search", search_a),
            make_tc("terminal", list_cmd),
            make_tc("terminal", list_cmd),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]

        deduped = AIAgent._deduplicate_tool_calls(calls)

        assert len(deduped) == 3

    def test_same_tool_different_args_kept(self):
        """Same tool with differing arguments returns the same list object."""
        calls = [
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]

        deduped = AIAgent._deduplicate_tool_calls(calls)

        assert deduped is calls

    def test_different_tools_same_args_kept(self):
        """Different tools sharing arguments returns the same list object."""
        shared_args = '{"x":1}'
        calls = [make_tc("tool_a", shared_args), make_tc("tool_b", shared_args)]

        deduped = AIAgent._deduplicate_tool_calls(calls)

        assert deduped is calls

    def test_clean_list_unchanged(self):
        """A list with no duplicates returns the same list object."""
        calls = [
            make_tc("web_search", '{"q":"x"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
        ]

        deduped = AIAgent._deduplicate_tool_calls(calls)

        assert deduped is calls

    def test_empty_list_safe(self):
        """An empty list yields an empty list."""
        assert AIAgent._deduplicate_tool_calls([]) == []

    def test_first_occurrence_kept(self):
        """Of two identical calls, the earlier object is the survivor."""
        first = make_tc("terminal", '{"cmd":"ls"}')
        second = make_tc("terminal", '{"cmd":"ls"}')

        deduped = AIAgent._deduplicate_tool_calls([first, second])

        assert len(deduped) == 1
        assert deduped[0] is first

    def test_original_list_not_mutated(self):
        """The input list keeps its length after deduplication."""
        calls = [make_tc("web_search", '{"q":"dup"}') for _ in range(2)]
        length_before = len(calls)

        AIAgent._deduplicate_tool_calls(calls)

        assert len(calls) == length_before
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# _get_tool_call_id_static
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestGetToolCallIdStatic:
    """Checks for ``AIAgent._get_tool_call_id_static`` on dicts and objects."""

    def test_dict_with_valid_id(self):
        """A dict with a string id returns that id."""
        call = {"id": "call_123"}
        assert AIAgent._get_tool_call_id_static(call) == "call_123"

    def test_dict_with_none_id(self):
        """A dict whose id is None returns the empty string."""
        call = {"id": None}
        assert AIAgent._get_tool_call_id_static(call) == ""

    def test_dict_without_id_key(self):
        """A dict lacking the id key returns the empty string."""
        call = {"function": {}}
        assert AIAgent._get_tool_call_id_static(call) == ""

    def test_object_with_valid_id(self):
        """An object with a string id attribute returns that id."""
        call = types.SimpleNamespace(id="call_456")
        assert AIAgent._get_tool_call_id_static(call) == "call_456"

    def test_object_with_none_id(self):
        """An object whose id attribute is None returns the empty string."""
        call = types.SimpleNamespace(id=None)
        assert AIAgent._get_tool_call_id_static(call) == ""

    def test_object_without_id_attr(self):
        """An object with no id attribute returns the empty string."""
        call = types.SimpleNamespace()
        assert AIAgent._get_tool_call_id_static(call) == ""
|