Salvage of PR #1321 by @alireza78a (concept cherry-picked and reimplemented against current main). Phase 1 — Pre-call message sanitization: _sanitize_api_messages() now runs unconditionally before every LLM call. It was previously gated on context_compressor being present, so sessions loaded from disk or running without compression could accumulate dangling tool_call/tool_result pairs that caused API errors. Phase 2a — Delegate task cap: _cap_delegate_task_calls() truncates excess delegate_task calls per turn to MAX_CONCURRENT_CHILDREN. The existing cap in delegate_tool.py only limits the task array within a single call; this new cap also catches multiple separate delegate_task tool_calls in one turn. Phase 2b — Tool call deduplication: _deduplicate_tool_calls() drops duplicate (tool_name, arguments) pairs within a single turn when models stutter. All three are static methods on AIAgent and are independently testable. 29 tests cover happy paths and edge cases.
264 lines
9.4 KiB
Python
264 lines
9.4 KiB
Python
"""Unit tests for AIAgent pre/post-LLM-call guardrails.

Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):

- _sanitize_api_messages() — Phase 1: orphaned tool pair repair
- _cap_delegate_task_calls() — Phase 2a: subagent concurrency limit
- _deduplicate_tool_calls() — Phase 2b: identical call deduplication

Also covers the _get_tool_call_id_static() helper, which reads a tool-call
id from either a dict or an attribute-style (SDK) object and returns ""
when the id is missing or None.
"""
|
|
|
|
import types
|
|
|
|
from run_agent import AIAgent
|
|
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
    """Build a bare-bones stand-in for an OpenAI SDK tool_call object.

    Only ``.function.name`` and ``.function.arguments`` are populated; the
    returned object carries no ``id`` attribute.
    """
    fn = types.SimpleNamespace(name=name, arguments=arguments)
    return types.SimpleNamespace(function=fn)
|
|
|
|
|
|
def tool_result(call_id: str, content: str = "ok") -> dict:
    """Return a history-style ``role="tool"`` message answering *call_id*."""
    message = dict(role="tool", tool_call_id=call_id)
    message["content"] = content
    return message
|
|
|
|
|
|
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
    """Build a dict-style tool_call entry, shaped like those stored in message history.

    The ``arguments`` field is always the empty JSON object string ``"{}"``.
    """
    function_payload = {"name": name, "arguments": "{}"}
    return {"id": call_id, "function": function_payload}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Phase 1 — _sanitize_api_messages
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSanitizeApiMessages:
    """AIAgent._sanitize_api_messages: repair of unmatched tool_call/tool_result pairs."""

    def test_orphaned_result_removed(self):
        history = [
            {"role": "assistant", "tool_calls": [assistant_dict_call("c1")]},
            tool_result("c1"),
            tool_result("c_ORPHAN"),
        ]
        sanitized = AIAgent._sanitize_api_messages(history)
        assert len(sanitized) == 2
        # A tool result with no matching assistant tool_call must be dropped.
        surviving_ids = [m.get("tool_call_id") for m in sanitized]
        assert "c_ORPHAN" not in surviving_ids

    def test_orphaned_call_gets_stub_result(self):
        history = [{"role": "assistant", "tool_calls": [assistant_dict_call("c2")]}]
        sanitized = AIAgent._sanitize_api_messages(history)
        assert len(sanitized) == 2
        # The unanswered call should be followed by a stub tool message
        # carrying non-empty content.
        stub = sanitized[1]
        assert stub["role"] == "tool"
        assert stub["tool_call_id"] == "c2"
        assert stub["content"]

    def test_clean_messages_pass_through(self):
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
            tool_result("c3"),
            {"role": "assistant", "content": "done"},
        ]
        sanitized = AIAgent._sanitize_api_messages(history)
        assert sanitized == history

    def test_mixed_orphaned_result_and_orphaned_call(self):
        history = [
            {
                "role": "assistant",
                "tool_calls": [assistant_dict_call("c4"), assistant_dict_call("c5")],
            },
            tool_result("c4"),
            tool_result("c_DANGLING"),
        ]
        sanitized = AIAgent._sanitize_api_messages(history)
        tool_ids = [m.get("tool_call_id") for m in sanitized if m.get("role") == "tool"]
        assert "c_DANGLING" not in tool_ids
        for expected_id in ("c4", "c5"):
            assert expected_id in tool_ids

    def test_empty_list_is_safe(self):
        sanitized = AIAgent._sanitize_api_messages([])
        assert sanitized == []

    def test_no_tool_messages(self):
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ]
        sanitized = AIAgent._sanitize_api_messages(history)
        assert sanitized == history

    def test_sdk_object_tool_calls(self):
        # Attribute-style tool_call (SDK object) rather than a plain dict.
        sdk_call = types.SimpleNamespace(
            id="c6",
            function=types.SimpleNamespace(name="terminal", arguments="{}"),
        )
        history = [{"role": "assistant", "tool_calls": [sdk_call]}]
        sanitized = AIAgent._sanitize_api_messages(history)
        assert len(sanitized) == 2
        assert sanitized[1]["tool_call_id"] == "c6"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Phase 2a — _cap_delegate_task_calls
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestCapDelegateTaskCalls:
    """AIAgent._cap_delegate_task_calls: per-turn cap on delegate_task calls.

    Expected contract (as pinned by these tests): keep the first
    MAX_CONCURRENT_CHILDREN delegate_task calls in their original order,
    keep every non-delegate call, return the input list itself when nothing
    needs dropping, and never mutate the input list.
    """

    def test_excess_delegates_truncated(self):
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        out = AIAgent._cap_delegate_task_calls(tcs)
        delegate_count = sum(1 for tc in out if tc.function.name == "delegate_task")
        assert delegate_count == MAX_CONCURRENT_CHILDREN
        # The survivors must be the *first* MAX_CONCURRENT_CHILDREN calls,
        # not an arbitrary subset of them.
        assert all(a is b for a, b in zip(out, tcs[:MAX_CONCURRENT_CHILDREN]))

    def test_non_delegate_calls_preserved(self):
        tcs = (
            [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)]
            + [make_tc("terminal"), make_tc("web_search")]
        )
        out = AIAgent._cap_delegate_task_calls(tcs)
        names = [tc.function.name for tc in out]
        assert "terminal" in names
        assert "web_search" in names

    def test_at_limit_passes_through(self):
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]
        out = AIAgent._cap_delegate_task_calls(tcs)
        assert out is tcs

    def test_below_limit_passes_through(self):
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]
        out = AIAgent._cap_delegate_task_calls(tcs)
        assert out is tcs

    def test_no_delegate_calls_unchanged(self):
        tcs = [make_tc("terminal"), make_tc("web_search")]
        out = AIAgent._cap_delegate_task_calls(tcs)
        assert out is tcs

    def test_empty_list_safe(self):
        assert AIAgent._cap_delegate_task_calls([]) == []

    def test_original_list_not_mutated(self):
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        snapshot = list(tcs)
        AIAgent._cap_delegate_task_calls(tcs)
        # Compare element identity, not just length, so that in-place
        # reordering or element replacement is caught as well.
        assert len(tcs) == len(snapshot)
        assert all(a is b for a, b in zip(tcs, snapshot))

    def test_interleaved_order_preserved(self):
        delegates = [make_tc("delegate_task", f'{{"task":"{i}"}}')
                     for i in range(MAX_CONCURRENT_CHILDREN + 1)]
        t1 = make_tc("terminal", '{"cmd":"ls"}')
        w1 = make_tc("web_search", '{"q":"x"}')
        tcs = [delegates[0], t1, delegates[1], w1] + delegates[2:]
        out = AIAgent._cap_delegate_task_calls(tcs)
        # Derive the expectation from tcs rather than hard-coded slices: the
        # previous `delegates[2:MAX]` form silently assumed a limit of at
        # least 2. This stays valid for any MAX_CONCURRENT_CHILDREN >= 1.
        kept = delegates[:MAX_CONCURRENT_CHILDREN]
        expected = [
            tc for tc in tcs
            if tc.function.name != "delegate_task" or any(tc is d for d in kept)
        ]
        assert len(out) == len(expected)
        for i, (actual, exp) in enumerate(zip(out, expected)):
            assert actual is exp, f"mismatch at index {i}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Phase 2b — _deduplicate_tool_calls
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDeduplicateToolCalls:
    """AIAgent._deduplicate_tool_calls: drop repeated (name, arguments) calls.

    Expected contract (as pinned by these tests): the first occurrence of
    each distinct (name, arguments) pair is kept, the input list itself is
    returned when there are no duplicates, and the input is never mutated.
    """

    def test_duplicate_pair_deduplicated(self):
        tcs = [
            make_tc("web_search", '{"query":"foo"}'),
            make_tc("web_search", '{"query":"foo"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert len(out) == 1

    def test_multiple_duplicates(self):
        tcs = [
            make_tc("web_search", '{"q":"a"}'),
            make_tc("web_search", '{"q":"a"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert len(out) == 3
        # A count alone would accept the wrong survivors (e.g. both "ls"
        # calls kept and "pwd" dropped). Pin the first occurrence of each
        # distinct pair, compared by identity and independent of order.
        assert {id(tc) for tc in out} == {id(tcs[0]), id(tcs[2]), id(tcs[4])}

    def test_same_tool_different_args_kept(self):
        tcs = [
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert out is tcs

    def test_different_tools_same_args_kept(self):
        tcs = [
            make_tc("tool_a", '{"x":1}'),
            make_tc("tool_b", '{"x":1}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert out is tcs

    def test_clean_list_unchanged(self):
        tcs = [
            make_tc("web_search", '{"q":"x"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert out is tcs

    def test_empty_list_safe(self):
        assert AIAgent._deduplicate_tool_calls([]) == []

    def test_first_occurrence_kept(self):
        tc1 = make_tc("terminal", '{"cmd":"ls"}')
        tc2 = make_tc("terminal", '{"cmd":"ls"}')
        out = AIAgent._deduplicate_tool_calls([tc1, tc2])
        assert len(out) == 1
        assert out[0] is tc1

    def test_original_list_not_mutated(self):
        tcs = [
            make_tc("web_search", '{"q":"dup"}'),
            make_tc("web_search", '{"q":"dup"}'),
        ]
        snapshot = list(tcs)
        AIAgent._deduplicate_tool_calls(tcs)
        # Compare element identity, not just length, so that in-place
        # reordering or element replacement is caught as well.
        assert len(tcs) == len(snapshot)
        assert all(a is b for a, b in zip(tcs, snapshot))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _get_tool_call_id_static
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestGetToolCallIdStatic:
    """AIAgent._get_tool_call_id_static on dict-style and attribute-style calls.

    A missing or None id is expected to come back as an empty string.
    """

    def test_dict_with_valid_id(self):
        call = {"id": "call_123"}
        assert AIAgent._get_tool_call_id_static(call) == "call_123"

    def test_dict_with_none_id(self):
        call = {"id": None}
        assert AIAgent._get_tool_call_id_static(call) == ""

    def test_dict_without_id_key(self):
        call = {"function": {}}
        assert AIAgent._get_tool_call_id_static(call) == ""

    def test_object_with_valid_id(self):
        extracted = AIAgent._get_tool_call_id_static(types.SimpleNamespace(id="call_456"))
        assert extracted == "call_456"

    def test_object_with_none_id(self):
        extracted = AIAgent._get_tool_call_id_static(types.SimpleNamespace(id=None))
        assert extracted == ""

    def test_object_without_id_attr(self):
        # Namespace with no attributes at all, so there is no `id` to read.
        extracted = AIAgent._get_tool_call_id_static(types.SimpleNamespace())
        assert extracted == ""
|