feat: pre-call sanitization and post-call tool guardrails (#1732)
Salvage of PR #1321 by @alireza78a (cherry-picked concept, reimplemented against current main).

Phase 1 — Pre-call message sanitization: _sanitize_api_messages() now runs unconditionally before every LLM call. Previously it was gated on context_compressor being present, so sessions loaded from disk or running without compression could accumulate dangling tool_call/tool_result pairs, causing API errors.

Phase 2a — Delegate task cap: _cap_delegate_task_calls() truncates excess delegate_task calls per turn to MAX_CONCURRENT_CHILDREN. The existing cap in delegate_tool.py only limits the task array within a single call; this catches multiple separate delegate_task tool_calls in one turn.

Phase 2b — Tool call deduplication: _deduplicate_tool_calls() drops duplicate (tool_name, arguments) pairs within a single turn when models stutter.

All three are static methods on AIAgent and are independently testable. 29 tests cover happy paths and edge cases.
This commit is contained in:
138
run_agent.py
138
run_agent.py
@@ -1957,7 +1957,124 @@ class AIAgent:
|
|||||||
prompt_parts.append(PLATFORM_HINTS[platform_key])
|
prompt_parts.append(PLATFORM_HINTS[platform_key])
|
||||||
|
|
||||||
return "\n\n".join(prompt_parts)
|
return "\n\n".join(prompt_parts)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Pre/post-call guardrails (inspired by PR #1321 — @alireza78a)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_tool_call_id_static(tc) -> str:
|
||||||
|
"""Extract call ID from a tool_call entry (dict or object)."""
|
||||||
|
if isinstance(tc, dict):
|
||||||
|
return tc.get("id", "") or ""
|
||||||
|
return getattr(tc, "id", "") or ""
|
||||||
|
|
||||||
|
@staticmethod
def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Repair dangling tool_call / tool_result pairs before an LLM call.

    Two repairs are applied, in this order:

    1. Tool messages whose ``tool_call_id`` is not announced by any
       assistant ``tool_calls`` entry are dropped.
    2. For every announced call ID that has no tool message, a stub tool
       message is inserted directly after the announcing assistant message.

    The input list is never modified in place; when a repair is needed a
    new list is built (reusing the same message dicts). When nothing needs
    repairing the original list object is returned.
    """
    # IDs announced by assistant messages (entries without an ID are ignored).
    announced_ids = {
        call_id
        for entry in messages
        if entry.get("role") == "assistant"
        for call_id in (
            AIAgent._get_tool_call_id_static(call)
            for call in entry.get("tool_calls") or []
        )
        if call_id
    }

    # IDs that some tool message answers (falsy IDs are ignored).
    answered_ids = {
        entry.get("tool_call_id")
        for entry in messages
        if entry.get("role") == "tool" and entry.get("tool_call_id")
    }

    # Repair 1: discard results nobody asked for.
    stray_results = answered_ids - announced_ids
    if stray_results:
        messages = [
            entry
            for entry in messages
            if entry.get("role") != "tool"
            or entry.get("tool_call_id") not in stray_results
        ]
        logger.debug(
            "Pre-call sanitizer: removed %d orphaned tool result(s)",
            len(stray_results),
        )

    # Repair 2: give every unanswered call a placeholder result.
    unanswered_calls = announced_ids - answered_ids
    if not unanswered_calls:
        return messages

    repaired = []
    for entry in messages:
        repaired.append(entry)
        if entry.get("role") != "assistant":
            continue
        for call in entry.get("tool_calls") or []:
            call_id = AIAgent._get_tool_call_id_static(call)
            if call_id in unanswered_calls:
                repaired.append({
                    "role": "tool",
                    "content": "[Result unavailable — see context summary above]",
                    "tool_call_id": call_id,
                })
    logger.debug(
        "Pre-call sanitizer: added %d stub tool result(s)",
        len(unanswered_calls),
    )
    return repaired
|
||||||
|
|
||||||
|
@staticmethod
def _cap_delegate_task_calls(tool_calls: list) -> list:
    """Limit the delegate_task calls in one turn to MAX_CONCURRENT_CHILDREN.

    The first MAX_CONCURRENT_CHILDREN delegate_task calls are kept, any
    later ones are dropped, and every non-delegate call is preserved in
    its original position.

    Returns the very same list object when the limit is not exceeded;
    otherwise returns a new list and logs a warning. The input list is
    never mutated.
    """
    # Imported at call time rather than at module import time.
    from tools.delegate_tool import MAX_CONCURRENT_CHILDREN

    def _is_delegate(call) -> bool:
        return call.function.name == "delegate_task"

    total_delegates = len([call for call in tool_calls if _is_delegate(call)])
    if total_delegates <= MAX_CONCURRENT_CHILDREN:
        return tool_calls

    capped = []
    remaining_slots = MAX_CONCURRENT_CHILDREN
    for call in tool_calls:
        if not _is_delegate(call):
            capped.append(call)
        elif remaining_slots > 0:
            capped.append(call)
            remaining_slots -= 1

    logger.warning(
        "Truncated %d excess delegate_task call(s) to enforce "
        "MAX_CONCURRENT_CHILDREN=%d limit",
        total_delegates - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
    )
    return capped
|
||||||
|
|
||||||
|
@staticmethod
def _deduplicate_tool_calls(tool_calls: list) -> list:
    """Drop repeated (tool_name, arguments) pairs from a single turn.

    The first call with a given name/arguments pair wins; each later
    repeat is discarded with a warning. Arguments are compared exactly
    as stored on ``tc.function.arguments``.

    Returns the very same list object when nothing was removed, otherwise
    a new list. The input list is never mutated.
    """
    fingerprints = set()
    kept = []
    for call in tool_calls:
        fn = call.function
        fingerprint = (fn.name, fn.arguments)
        if fingerprint in fingerprints:
            logger.warning("Removed duplicate tool call: %s", fn.name)
            continue
        fingerprints.add(fingerprint)
        kept.append(call)

    if len(kept) < len(tool_calls):
        return kept
    return tool_calls
|
||||||
|
|
||||||
def _repair_tool_call(self, tool_name: str) -> str | None:
|
def _repair_tool_call(self, tool_name: str) -> str | None:
|
||||||
"""Attempt to repair a mismatched tool name before aborting.
|
"""Attempt to repair a mismatched tool name before aborting.
|
||||||
|
|
||||||
@@ -4992,11 +5109,10 @@ class AIAgent:
|
|||||||
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
|
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
|
||||||
|
|
||||||
# Safety net: strip orphaned tool results / add stubs for missing
|
# Safety net: strip orphaned tool results / add stubs for missing
|
||||||
# results before sending to the API. The compressor handles this
|
# results before sending to the API. Runs unconditionally — not
|
||||||
# during compression, but orphans can also sneak in from session
|
# gated on context_compressor — so orphans from session loading or
|
||||||
# loading or manual message manipulation.
|
# manual message manipulation are always caught.
|
||||||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
api_messages = self._sanitize_api_messages(api_messages)
|
||||||
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
|
|
||||||
|
|
||||||
# Calculate approximate request size for logging
|
# Calculate approximate request size for logging
|
||||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||||
@@ -6026,7 +6142,15 @@ class AIAgent:
|
|||||||
|
|
||||||
# Reset retry counter on successful JSON validation
|
# Reset retry counter on successful JSON validation
|
||||||
self._invalid_json_retries = 0
|
self._invalid_json_retries = 0
|
||||||
|
|
||||||
|
# ── Post-call guardrails ──────────────────────────
|
||||||
|
assistant_message.tool_calls = self._cap_delegate_task_calls(
|
||||||
|
assistant_message.tool_calls
|
||||||
|
)
|
||||||
|
assistant_message.tool_calls = self._deduplicate_tool_calls(
|
||||||
|
assistant_message.tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||||
|
|
||||||
# If this turn has both content AND tool_calls, capture the content
|
# If this turn has both content AND tool_calls, capture the content
|
||||||
|
|||||||
263
tests/test_agent_guardrails.py
Normal file
263
tests/test_agent_guardrails.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
"""Unit tests for AIAgent pre/post-LLM-call guardrails.
|
||||||
|
|
||||||
|
Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
|
||||||
|
- _sanitize_api_messages() — Phase 1: orphaned tool pair repair
|
||||||
|
- _cap_delegate_task_calls() — Phase 2a: subagent concurrency limit
|
||||||
|
- _deduplicate_tool_calls() — Phase 2b: identical call deduplication
|
||||||
|
"""
|
||||||
|
|
||||||
|
import types
|
||||||
|
|
||||||
|
from run_agent import AIAgent
|
||||||
|
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
    """Build a bare tool_call stand-in exposing ``.function.name`` / ``.function.arguments``."""
    function = types.SimpleNamespace(name=name, arguments=arguments)
    return types.SimpleNamespace(function=function)
|
||||||
|
|
||||||
|
|
||||||
|
def tool_result(call_id: str, content: str = "ok") -> dict:
    """Build a tool-role message dict answering the call with ID *call_id*."""
    return dict(role="tool", tool_call_id=call_id, content=content)
|
||||||
|
|
||||||
|
|
||||||
|
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
    """Build a dict-style tool_call entry (the shape stored in message history)."""
    function_spec = {"name": name, "arguments": "{}"}
    return {"id": call_id, "function": function_spec}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Phase 1 — _sanitize_api_messages
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSanitizeApiMessages:
    """Phase 1: repair of dangling tool_call / tool_result pairs."""

    def test_orphaned_result_removed(self):
        """A tool result with no announcing assistant call is dropped."""
        history = [
            {"role": "assistant", "tool_calls": [assistant_dict_call("c1")]},
            tool_result("c1"),
            tool_result("c_ORPHAN"),
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert len(cleaned) == 2
        assert "c_ORPHAN" not in [m.get("tool_call_id") for m in cleaned]

    def test_orphaned_call_gets_stub_result(self):
        """An assistant call without a result gets a non-empty stub result."""
        history = [
            {"role": "assistant", "tool_calls": [assistant_dict_call("c2")]},
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert len(cleaned) == 2
        placeholder = cleaned[1]
        assert placeholder["role"] == "tool"
        assert placeholder["tool_call_id"] == "c2"
        assert placeholder["content"]

    def test_clean_messages_pass_through(self):
        """A fully paired history comes back equal to the input."""
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
            tool_result("c3"),
            {"role": "assistant", "content": "done"},
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert cleaned == history

    def test_mixed_orphaned_result_and_orphaned_call(self):
        """Both repairs apply to the same history: drop the stray, stub the missing."""
        paired_calls = [assistant_dict_call("c4"), assistant_dict_call("c5")]
        history = [
            {"role": "assistant", "tool_calls": paired_calls},
            tool_result("c4"),
            tool_result("c_DANGLING"),
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        result_ids = [
            m.get("tool_call_id") for m in cleaned if m.get("role") == "tool"
        ]
        assert "c_DANGLING" not in result_ids
        assert "c4" in result_ids
        assert "c5" in result_ids

    def test_empty_list_is_safe(self):
        """An empty history yields an empty history."""
        assert AIAgent._sanitize_api_messages([]) == []

    def test_no_tool_messages(self):
        """A history without any tool traffic comes back equal to the input."""
        history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert cleaned == history

    def test_sdk_object_tool_calls(self):
        """Object-style tool_calls (attribute access) are stubbed like dict-style ones."""
        sdk_function = types.SimpleNamespace(name="terminal", arguments="{}")
        sdk_call = types.SimpleNamespace(id="c6", function=sdk_function)
        history = [
            {"role": "assistant", "tool_calls": [sdk_call]},
        ]
        cleaned = AIAgent._sanitize_api_messages(history)
        assert len(cleaned) == 2
        assert cleaned[1]["tool_call_id"] == "c6"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Phase 2a — _cap_delegate_task_calls
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCapDelegateTaskCalls:
    """Phase 2a: per-turn cap on delegate_task tool calls."""

    def test_excess_delegates_truncated(self):
        """Two calls over the limit are trimmed down to exactly the limit."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        capped = AIAgent._cap_delegate_task_calls(calls)
        surviving = [c for c in capped if c.function.name == "delegate_task"]
        assert len(surviving) == MAX_CONCURRENT_CHILDREN

    def test_non_delegate_calls_preserved(self):
        """Truncation never removes calls to other tools."""
        delegates = [
            make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)
        ]
        others = [make_tc("terminal"), make_tc("web_search")]
        capped = AIAgent._cap_delegate_task_calls(delegates + others)
        surviving_names = [c.function.name for c in capped]
        assert "terminal" in surviving_names
        assert "web_search" in surviving_names

    def test_at_limit_passes_through(self):
        """Exactly the limit returns the identical list object."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]
        capped = AIAgent._cap_delegate_task_calls(calls)
        assert capped is calls

    def test_below_limit_passes_through(self):
        """One under the limit returns the identical list object."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]
        capped = AIAgent._cap_delegate_task_calls(calls)
        assert capped is calls

    def test_no_delegate_calls_unchanged(self):
        """A turn without delegate_task returns the identical list object."""
        calls = [make_tc("terminal"), make_tc("web_search")]
        capped = AIAgent._cap_delegate_task_calls(calls)
        assert capped is calls

    def test_empty_list_safe(self):
        """An empty turn yields an empty list."""
        assert AIAgent._cap_delegate_task_calls([]) == []

    def test_original_list_not_mutated(self):
        """Truncation leaves the length of the caller's list untouched."""
        calls = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
        length_before = len(calls)
        AIAgent._cap_delegate_task_calls(calls)
        assert len(calls) == length_before

    def test_interleaved_order_preserved(self):
        """Kept calls stay in order when delegates and other tools are interleaved."""
        delegates = [
            make_tc("delegate_task", f'{{"task":"{i}"}}')
            for i in range(MAX_CONCURRENT_CHILDREN + 1)
        ]
        terminal_call = make_tc("terminal", '{"cmd":"ls"}')
        search_call = make_tc("web_search", '{"q":"x"}')
        head = [delegates[0], terminal_call, delegates[1], search_call]

        capped = AIAgent._cap_delegate_task_calls(head + delegates[2:])

        # Only the trailing delegate beyond the limit should be gone.
        expected = head + delegates[2:MAX_CONCURRENT_CHILDREN]
        assert len(capped) == len(expected)
        for i, (actual, exp) in enumerate(zip(capped, expected)):
            assert actual is exp, f"mismatch at index {i}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Phase 2b — _deduplicate_tool_calls
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDeduplicateToolCalls:
    """Phase 2b: removal of identical (name, arguments) calls in one turn."""

    def test_duplicate_pair_deduplicated(self):
        """Two identical calls collapse into one."""
        calls = [
            make_tc("web_search", '{"query":"foo"}'),
            make_tc("web_search", '{"query":"foo"}'),
        ]
        deduped = AIAgent._deduplicate_tool_calls(calls)
        assert len(deduped) == 1

    def test_multiple_duplicates(self):
        """Several duplicated pairs are each reduced to a single call."""
        calls = [
            make_tc("web_search", '{"q":"a"}'),
            make_tc("web_search", '{"q":"a"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        deduped = AIAgent._deduplicate_tool_calls(calls)
        assert len(deduped) == 3

    def test_same_tool_different_args_kept(self):
        """Same tool with different arguments is not a duplicate."""
        calls = [
            make_tc("terminal", '{"cmd":"ls"}'),
            make_tc("terminal", '{"cmd":"pwd"}'),
        ]
        deduped = AIAgent._deduplicate_tool_calls(calls)
        assert deduped is calls

    def test_different_tools_same_args_kept(self):
        """Different tools with the same arguments are not duplicates."""
        calls = [
            make_tc("tool_a", '{"x":1}'),
            make_tc("tool_b", '{"x":1}'),
        ]
        deduped = AIAgent._deduplicate_tool_calls(calls)
        assert deduped is calls

    def test_clean_list_unchanged(self):
        """A turn without duplicates returns the identical list object."""
        calls = [
            make_tc("web_search", '{"q":"x"}'),
            make_tc("terminal", '{"cmd":"ls"}'),
        ]
        deduped = AIAgent._deduplicate_tool_calls(calls)
        assert deduped is calls

    def test_empty_list_safe(self):
        """An empty turn yields an empty list."""
        assert AIAgent._deduplicate_tool_calls([]) == []

    def test_first_occurrence_kept(self):
        """Of two identical calls, the earlier object is the survivor."""
        first = make_tc("terminal", '{"cmd":"ls"}')
        repeat = make_tc("terminal", '{"cmd":"ls"}')
        deduped = AIAgent._deduplicate_tool_calls([first, repeat])
        assert len(deduped) == 1
        assert deduped[0] is first

    def test_original_list_not_mutated(self):
        """Deduplication leaves the length of the caller's list untouched."""
        calls = [
            make_tc("web_search", '{"q":"dup"}'),
            make_tc("web_search", '{"q":"dup"}'),
        ]
        length_before = len(calls)
        AIAgent._deduplicate_tool_calls(calls)
        assert len(calls) == length_before
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _get_tool_call_id_static
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetToolCallIdStatic:
    """ID extraction from dict-style and object-style tool_call entries."""

    def test_dict_with_valid_id(self):
        """A dict entry returns its "id" value."""
        entry = {"id": "call_123"}
        assert AIAgent._get_tool_call_id_static(entry) == "call_123"

    def test_dict_with_none_id(self):
        """A dict entry whose "id" is None yields the empty string."""
        entry = {"id": None}
        assert AIAgent._get_tool_call_id_static(entry) == ""

    def test_dict_without_id_key(self):
        """A dict entry lacking "id" yields the empty string."""
        entry = {"function": {}}
        assert AIAgent._get_tool_call_id_static(entry) == ""

    def test_object_with_valid_id(self):
        """An object entry returns its ``id`` attribute."""
        entry = types.SimpleNamespace(id="call_456")
        assert AIAgent._get_tool_call_id_static(entry) == "call_456"

    def test_object_with_none_id(self):
        """An object entry whose ``id`` is None yields the empty string."""
        entry = types.SimpleNamespace(id=None)
        assert AIAgent._get_tool_call_id_static(entry) == ""

    def test_object_without_id_attr(self):
        """An object entry lacking ``id`` yields the empty string."""
        entry = types.SimpleNamespace()
        assert AIAgent._get_tool_call_id_static(entry) == ""
|
||||||
Reference in New Issue
Block a user