* refactor: re-architect tests to mirror the codebase
* Update tests.yml
* fix: add missing tool_error imports after registry refactor
* fix(tests): replace patch.dict with monkeypatch to prevent env var leaks under xdist
patch.dict(os.environ) can leak TERMINAL_ENV across xdist workers,
causing test_code_execution tests to hit the Modal remote path.
* fix(tests): fix update_check and telegram xdist failures
- test_update_check: replace patch("hermes_cli.banner.os.getenv") with
monkeypatch.setenv("HERMES_HOME") — banner.py no longer imports os
directly, it uses get_hermes_home() from hermes_constants.
- test_telegram_conflict/approval_buttons: provide real exception classes
for telegram.error mock (NetworkError, TimedOut, BadRequest) so the
except clause in connect() doesn't fail with "catching classes that do
not inherit from BaseException" when xdist pollutes sys.modules.
* fix(tests): accept unavailable_models kwarg in _prompt_model_selection mock
264 lines
9.4 KiB
Python
264 lines
9.4 KiB
Python
"""Unit tests for AIAgent pre/post-LLM-call guardrails.

Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
  - _sanitize_api_messages()   — Phase 1: orphaned tool pair repair
  - _cap_delegate_task_calls() — Phase 2a: subagent concurrency limit
  - _deduplicate_tool_calls()  — Phase 2b: identical call deduplication

Also exercises _get_tool_call_id_static() with dict-style and object-style
tool calls.
"""
|
|
|
|
import types
|
|
|
|
from run_agent import AIAgent
|
|
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
|
|
|
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
    """Create a minimal tool_call SimpleNamespace mirroring the OpenAI SDK object.

    Args:
        name: Tool name, stored at ``tc.function.name``.
        arguments: Raw argument string, stored at ``tc.function.arguments``.
            Defaults to an empty JSON object.

    Returns:
        A namespace exposing only a ``function`` attribute (no ``id`` is set).
    """
    return types.SimpleNamespace(
        function=types.SimpleNamespace(name=name, arguments=arguments),
    )
|
|
|
|
|
|
def tool_result(call_id: str, content: str = "ok") -> dict:
    """Build a ``role="tool"`` message dict answering the tool call *call_id*.

    Args:
        call_id: Value stored under the ``tool_call_id`` key.
        content: Value stored under the ``content`` key; defaults to ``"ok"``.

    Returns:
        A new dict with the keys ``role``, ``tool_call_id`` and ``content``.
    """
    return {"role": "tool", "tool_call_id": call_id, "content": content}
|
|
|
|
|
|
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
    """Dict-style tool_call (as stored in message history).

    Args:
        call_id: Value stored under the ``id`` key.
        name: Tool name stored under ``function.name``; defaults to
            ``"terminal"``.

    Returns:
        A new dict with an ``id`` and a nested ``function`` dict whose
        ``arguments`` is always the empty JSON object string ``"{}"``.
    """
    return {"id": call_id, "function": {"name": name, "arguments": "{}"}}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Phase 1 — _sanitize_api_messages
# ---------------------------------------------------------------------------
|
|
|
|
class TestSanitizeApiMessages:
    """Tests for ``AIAgent._sanitize_api_messages`` (orphaned tool pair repair)."""

    def test_orphaned_result_removed(self):
        """A tool result whose id matches no assistant tool call is dropped."""
        msgs = [
            {"role": "assistant", "tool_calls": [assistant_dict_call("c1")]},
            tool_result("c1"),
            tool_result("c_ORPHAN"),
        ]
        out = AIAgent._sanitize_api_messages(msgs)
        assert len(out) == 2
        assert all(m.get("tool_call_id") != "c_ORPHAN" for m in out)

    def test_orphaned_call_gets_stub_result(self):
        """An assistant tool call with no result gets a non-empty stub result."""
        msgs = [
            {"role": "assistant", "tool_calls": [assistant_dict_call("c2")]},
        ]
        out = AIAgent._sanitize_api_messages(msgs)
        assert len(out) == 2
        stub = out[1]
        assert stub["role"] == "tool"
        assert stub["tool_call_id"] == "c2"
        assert stub["content"]

    def test_clean_messages_pass_through(self):
        """A history with fully paired calls and results compares equal afterwards."""
        msgs = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
            tool_result("c3"),
            {"role": "assistant", "content": "done"},
        ]
        out = AIAgent._sanitize_api_messages(msgs)
        assert out == msgs

    def test_mixed_orphaned_result_and_orphaned_call(self):
        """A dangling result is removed and an unanswered call is answered."""
        msgs = [
            {"role": "assistant", "tool_calls": [
                assistant_dict_call("c4"),
                assistant_dict_call("c5"),
            ]},
            tool_result("c4"),
            tool_result("c_DANGLING"),
        ]
        out = AIAgent._sanitize_api_messages(msgs)
        ids = [m.get("tool_call_id") for m in out if m.get("role") == "tool"]
        assert "c_DANGLING" not in ids
        assert "c4" in ids
        assert "c5" in ids

    def test_empty_list_is_safe(self):
        """An empty history yields an empty list."""
        assert AIAgent._sanitize_api_messages([]) == []

    def test_no_tool_messages(self):
        """A history without any tool calls or results compares equal afterwards."""
        msgs = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ]
        out = AIAgent._sanitize_api_messages(msgs)
        assert out == msgs

    def test_sdk_object_tool_calls(self):
        """Attribute-style (SDK object) tool calls are handled like dict-style ones."""
        tc_obj = types.SimpleNamespace(id="c6", function=types.SimpleNamespace(
            name="terminal", arguments="{}"
        ))
        msgs = [
            {"role": "assistant", "tool_calls": [tc_obj]},
        ]
        out = AIAgent._sanitize_api_messages(msgs)
        assert len(out) == 2
        assert out[1]["tool_call_id"] == "c6"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Phase 2a — _cap_delegate_task_calls
# ---------------------------------------------------------------------------
|
|
|
|
class TestCapDelegateTaskCalls:
    """Tests for ``AIAgent._cap_delegate_task_calls`` (subagent concurrency limit)."""

    def test_excess_delegates_truncated(self):
        """More than MAX_CONCURRENT_CHILDREN delegate calls are cut to the limit."""
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        out = AIAgent._cap_delegate_task_calls(tcs)
        delegate_count = sum(1 for tc in out if tc.function.name == "delegate_task")
        assert delegate_count == MAX_CONCURRENT_CHILDREN

    def test_non_delegate_calls_preserved(self):
        """Non-delegate calls survive even when delegate calls are truncated."""
        tcs = (
            [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)]
            + [make_tc("terminal"), make_tc("web_search")]
        )
        out = AIAgent._cap_delegate_task_calls(tcs)
        names = [tc.function.name for tc in out]
        assert "terminal" in names
        assert "web_search" in names

    def test_at_limit_passes_through(self):
        """Exactly MAX_CONCURRENT_CHILDREN delegates: the same list object comes back."""
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]
        out = AIAgent._cap_delegate_task_calls(tcs)
        assert out is tcs

    def test_below_limit_passes_through(self):
        """Fewer delegates than the limit: the same list object comes back."""
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]
        out = AIAgent._cap_delegate_task_calls(tcs)
        assert out is tcs

    def test_no_delegate_calls_unchanged(self):
        """A list without delegate calls comes back as the same list object."""
        tcs = [make_tc("terminal"), make_tc("web_search")]
        out = AIAgent._cap_delegate_task_calls(tcs)
        assert out is tcs

    def test_empty_list_safe(self):
        """An empty list yields an empty list."""
        assert AIAgent._cap_delegate_task_calls([]) == []

    def test_original_list_not_mutated(self):
        """Truncation must not shorten the caller's input list."""
        tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        original_len = len(tcs)
        AIAgent._cap_delegate_task_calls(tcs)
        assert len(tcs) == original_len

    def test_interleaved_order_preserved(self):
        """Kept calls retain their relative order when delegates are interleaved."""
        delegates = [make_tc("delegate_task", f'{{"task":"{i}"}}')
                     for i in range(MAX_CONCURRENT_CHILDREN + 1)]
        t1 = make_tc("terminal", '{"cmd":"ls"}')
        w1 = make_tc("web_search", '{"q":"x"}')
        tcs = [delegates[0], t1, delegates[1], w1] + delegates[2:]
        out = AIAgent._cap_delegate_task_calls(tcs)
        # Only the delegate beyond the limit (the last one) should be missing.
        expected = [delegates[0], t1, delegates[1], w1] + delegates[2:MAX_CONCURRENT_CHILDREN]
        assert len(out) == len(expected)
        for i, (actual, exp) in enumerate(zip(out, expected)):
            assert actual is exp, f"mismatch at index {i}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Phase 2b — _deduplicate_tool_calls
# ---------------------------------------------------------------------------
|
|
|
|
class TestDeduplicateToolCalls:
    """Tests for ``AIAgent._deduplicate_tool_calls`` (identical call deduplication)."""

    def test_duplicate_pair_deduplicated(self):
        """Two calls with the same tool name and arguments collapse to one."""
        tcs = [
            make_tc("web_search", '{"query":"foo"}'),
            make_tc("web_search", '{"query":"foo"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert len(out) == 1

    def test_multiple_duplicates(self):
        """Each distinct (name, arguments) pair survives exactly once."""
        tcs = [
            make_tc("web_search", '{"q":"a"}'),
            make_tc("web_search", '{"q":"a"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert len(out) == 3

    def test_same_tool_different_args_kept(self):
        """Same tool with different arguments: the same list object comes back."""
        tcs = [
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert out is tcs

    def test_different_tools_same_args_kept(self):
        """Different tools with equal arguments: the same list object comes back."""
        tcs = [
            make_tc("tool_a", '{"x":1}'),
            make_tc("tool_b", '{"x":1}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert out is tcs

    def test_clean_list_unchanged(self):
        """A list without duplicates comes back as the same list object."""
        tcs = [
            make_tc("web_search", '{"q":"x"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
        ]
        out = AIAgent._deduplicate_tool_calls(tcs)
        assert out is tcs

    def test_empty_list_safe(self):
        """An empty list yields an empty list."""
        assert AIAgent._deduplicate_tool_calls([]) == []

    def test_first_occurrence_kept(self):
        """Of two identical calls, the first object is the one retained."""
        tc1 = make_tc("terminal", '{"cmd":"ls"}')
        tc2 = make_tc("terminal", '{"cmd":"ls"}')
        out = AIAgent._deduplicate_tool_calls([tc1, tc2])
        assert len(out) == 1
        assert out[0] is tc1

    def test_original_list_not_mutated(self):
        """Deduplication must not shorten the caller's input list."""
        tcs = [
            make_tc("web_search", '{"q":"dup"}'),
            make_tc("web_search", '{"q":"dup"}'),
        ]
        original_len = len(tcs)
        AIAgent._deduplicate_tool_calls(tcs)
        assert len(tcs) == original_len
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# _get_tool_call_id_static
# ---------------------------------------------------------------------------
|
|
|
|
class TestGetToolCallIdStatic:
    """Tests for ``AIAgent._get_tool_call_id_static`` on dict and object inputs.

    Every "missing" variant (``None`` id, absent key, absent attribute) is
    expected to yield the empty string.
    """

    def test_dict_with_valid_id(self):
        """A dict's ``id`` value is returned as-is."""
        assert AIAgent._get_tool_call_id_static({"id": "call_123"}) == "call_123"

    def test_dict_with_none_id(self):
        """A dict whose ``id`` is ``None`` yields ``""``."""
        assert AIAgent._get_tool_call_id_static({"id": None}) == ""

    def test_dict_without_id_key(self):
        """A dict lacking the ``id`` key yields ``""``."""
        assert AIAgent._get_tool_call_id_static({"function": {}}) == ""

    def test_object_with_valid_id(self):
        """An object's ``id`` attribute is returned as-is."""
        tc = types.SimpleNamespace(id="call_456")
        assert AIAgent._get_tool_call_id_static(tc) == "call_456"

    def test_object_with_none_id(self):
        """An object whose ``id`` attribute is ``None`` yields ``""``."""
        tc = types.SimpleNamespace(id=None)
        assert AIAgent._get_tool_call_id_static(tc) == ""

    def test_object_without_id_attr(self):
        """An object lacking an ``id`` attribute yields ``""``."""
        tc = types.SimpleNamespace()
        assert AIAgent._get_tool_call_id_static(tc) == ""
|