diff --git a/run_agent.py b/run_agent.py index 1133d0771..bfe62e04c 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1957,7 +1957,124 @@ class AIAgent: prompt_parts.append(PLATFORM_HINTS[platform_key]) return "\n\n".join(prompt_parts) - + + # ========================================================================= + # Pre/post-call guardrails (inspired by PR #1321 — @alireza78a) + # ========================================================================= + + @staticmethod + def _get_tool_call_id_static(tc) -> str: + """Extract call ID from a tool_call entry (dict or object).""" + if isinstance(tc, dict): + return tc.get("id", "") or "" + return getattr(tc, "id", "") or "" + + @staticmethod + def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Fix orphaned tool_call / tool_result pairs before every LLM call. + + Runs unconditionally — not gated on whether the context compressor + is present — so orphans from session loading or manual message + manipulation are always caught. + """ + surviving_call_ids: set = set() + for msg in messages: + if msg.get("role") == "assistant": + for tc in msg.get("tool_calls") or []: + cid = AIAgent._get_tool_call_id_static(tc) + if cid: + surviving_call_ids.add(cid) + + result_call_ids: set = set() + for msg in messages: + if msg.get("role") == "tool": + cid = msg.get("tool_call_id") + if cid: + result_call_ids.add(cid) + + # 1. Drop tool results with no matching assistant call + orphaned_results = result_call_ids - surviving_call_ids + if orphaned_results: + messages = [ + m for m in messages + if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results) + ] + logger.debug( + "Pre-call sanitizer: removed %d orphaned tool result(s)", + len(orphaned_results), + ) + + # 2. Inject stub results for calls whose result was dropped + missing_results = surviving_call_ids - result_call_ids + if missing_results: + patched: List[Dict[str, Any]] = [] + for msg in messages: + patched.append(msg) + if msg.get("role") == "assistant": + for tc in msg.get("tool_calls") or []: + cid = AIAgent._get_tool_call_id_static(tc) + if cid in missing_results: + patched.append({ + "role": "tool", + "content": "[Result unavailable — see context summary above]", + "tool_call_id": cid, + }) + messages = patched + logger.debug( + "Pre-call sanitizer: added %d stub tool result(s)", + len(missing_results), + ) + + return messages + + @staticmethod + def _cap_delegate_task_calls(tool_calls: list) -> list: + """Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN. + + The delegate_tool caps the task list inside a single call, but the + model can emit multiple separate delegate_task tool_calls in one + turn. This truncates the excess, preserving all non-delegate calls. + + Returns the original list if no truncation was needed. + """ + from tools.delegate_tool import MAX_CONCURRENT_CHILDREN + delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task") + if delegate_count <= MAX_CONCURRENT_CHILDREN: + return tool_calls + kept_delegates = 0 + truncated = [] + for tc in tool_calls: + if tc.function.name == "delegate_task": + if kept_delegates < MAX_CONCURRENT_CHILDREN: + truncated.append(tc) + kept_delegates += 1 + else: + truncated.append(tc) + logger.warning( + "Truncated %d excess delegate_task call(s) to enforce " + "MAX_CONCURRENT_CHILDREN=%d limit", + delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN, + ) + return truncated + + @staticmethod + def _deduplicate_tool_calls(tool_calls: list) -> list: + """Remove duplicate (tool_name, arguments) pairs within a single turn. + + Only the first occurrence of each unique pair is kept. + Returns the original list if no duplicates were found. + """ + seen: set = set() + unique: list = [] + for tc in tool_calls: + key = (tc.function.name, tc.function.arguments) + if key not in seen: + seen.add(key) + unique.append(tc) + else: + logger.warning("Removed duplicate tool call: %s", tc.function.name) + return unique if len(unique) < len(tool_calls) else tool_calls + def _repair_tool_call(self, tool_name: str) -> str | None: """Attempt to repair a mismatched tool name before aborting. @@ -4992,11 +5109,10 @@ class AIAgent: api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl) # Safety net: strip orphaned tool results / add stubs for missing - # results before sending to the API. The compressor handles this - # during compression, but orphans can also sneak in from session - # loading or manual message manipulation. - if hasattr(self, 'context_compressor') and self.context_compressor: - api_messages = self.context_compressor._sanitize_tool_pairs(api_messages) + # results before sending to the API. Runs unconditionally — not + # gated on context_compressor — so orphans from session loading or + # manual message manipulation are always caught. + api_messages = self._sanitize_api_messages(api_messages) # Calculate approximate request size for logging total_chars = sum(len(str(msg)) for msg in api_messages) @@ -6026,7 +6142,15 @@ class AIAgent: # Reset retry counter on successful JSON validation self._invalid_json_retries = 0 - + + # ── Post-call guardrails ────────────────────────── + assistant_message.tool_calls = self._cap_delegate_task_calls( + assistant_message.tool_calls + ) + assistant_message.tool_calls = self._deduplicate_tool_calls( + assistant_message.tool_calls + ) + assistant_msg = self._build_assistant_message(assistant_message, finish_reason) # If this turn has both content AND tool_calls, capture the content diff --git a/tests/test_agent_guardrails.py b/tests/test_agent_guardrails.py new file mode 100644 index 000000000..706b1daf8 --- /dev/null +++ b/tests/test_agent_guardrails.py @@ -0,0 +1,263 @@ +"""Unit tests for AIAgent pre/post-LLM-call guardrails. + +Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a): + - _sanitize_api_messages() — Phase 1: orphaned tool pair repair + - _cap_delegate_task_calls() — Phase 2a: subagent concurrency limit + - _deduplicate_tool_calls() — Phase 2b: identical call deduplication +""" + +import types + +from run_agent import AIAgent +from tools.delegate_tool import MAX_CONCURRENT_CHILDREN + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace: + """Create a minimal tool_call SimpleNamespace mirroring the OpenAI SDK object.""" + tc = types.SimpleNamespace() + tc.function = types.SimpleNamespace(name=name, arguments=arguments) + return tc + + +def tool_result(call_id: str, content: str = "ok") -> dict: + return {"role": "tool", "tool_call_id": call_id, "content": content} + + +def assistant_dict_call(call_id: str, name: str = "terminal") -> dict: + """Dict-style tool_call (as stored in message history).""" + return {"id": call_id, "function": {"name": name, "arguments": "{}"}} + + +# --------------------------------------------------------------------------- +# Phase 1 — _sanitize_api_messages +# --------------------------------------------------------------------------- + +class TestSanitizeApiMessages: + + def test_orphaned_result_removed(self): + msgs = [ + {"role": "assistant", "tool_calls": [assistant_dict_call("c1")]}, + tool_result("c1"), + tool_result("c_ORPHAN"), + ] + out = AIAgent._sanitize_api_messages(msgs) + assert len(out) == 2 + assert all(m.get("tool_call_id") != "c_ORPHAN" for m in out) + + def test_orphaned_call_gets_stub_result(self): + msgs = [ + {"role": "assistant", "tool_calls": [assistant_dict_call("c2")]}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert len(out) == 2 + stub = out[1] + assert stub["role"] == "tool" + assert stub["tool_call_id"] == "c2" + assert stub["content"] + + def test_clean_messages_pass_through(self): + msgs = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "tool_calls": [assistant_dict_call("c3")]}, + tool_result("c3"), + {"role": "assistant", "content": "done"}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert out == msgs + + def test_mixed_orphaned_result_and_orphaned_call(self): + msgs = [ + {"role": "assistant", "tool_calls": [ + assistant_dict_call("c4"), + assistant_dict_call("c5"), + ]}, + tool_result("c4"), + tool_result("c_DANGLING"), + ] + out = AIAgent._sanitize_api_messages(msgs) + ids = [m.get("tool_call_id") for m in out if m.get("role") == "tool"] + assert "c_DANGLING" not in ids + assert "c4" in ids + assert "c5" in ids + + def test_empty_list_is_safe(self): + assert AIAgent._sanitize_api_messages([]) == [] + + def test_no_tool_messages(self): + msgs = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert out == msgs + + def test_sdk_object_tool_calls(self): + tc_obj = types.SimpleNamespace(id="c6", function=types.SimpleNamespace( + name="terminal", arguments="{}" + )) + msgs = [ + {"role": "assistant", "tool_calls": [tc_obj]}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert len(out) == 2 + assert out[1]["tool_call_id"] == "c6" + + +# --------------------------------------------------------------------------- +# Phase 2a — _cap_delegate_task_calls +# --------------------------------------------------------------------------- + +class TestCapDelegateTaskCalls: + + def test_excess_delegates_truncated(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)] + out = AIAgent._cap_delegate_task_calls(tcs) + delegate_count = sum(1 for tc in out if tc.function.name == "delegate_task") + assert delegate_count == MAX_CONCURRENT_CHILDREN + + def test_non_delegate_calls_preserved(self): + tcs = ( + [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)] + + [make_tc("terminal"), make_tc("web_search")] + ) + out = AIAgent._cap_delegate_task_calls(tcs) + names = [tc.function.name for tc in out] + assert "terminal" in names + assert "web_search" in names + + def test_at_limit_passes_through(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)] + out = AIAgent._cap_delegate_task_calls(tcs) + assert out is tcs + + def test_below_limit_passes_through(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)] + out = AIAgent._cap_delegate_task_calls(tcs) + assert out is tcs + + def test_no_delegate_calls_unchanged(self): + tcs = [make_tc("terminal"), make_tc("web_search")] + out = AIAgent._cap_delegate_task_calls(tcs) + assert out is tcs + + def test_empty_list_safe(self): + assert AIAgent._cap_delegate_task_calls([]) == [] + + def test_original_list_not_mutated(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)] + original_len = len(tcs) + AIAgent._cap_delegate_task_calls(tcs) + assert len(tcs) == original_len + + def test_interleaved_order_preserved(self): + delegates = [make_tc("delegate_task", f'{{"task":"{i}"}}') + for i in range(MAX_CONCURRENT_CHILDREN + 1)] + t1 = make_tc("terminal", '{"cmd":"ls"}') + w1 = make_tc("web_search", '{"q":"x"}') + tcs = [delegates[0], t1, delegates[1], w1] + delegates[2:] + out = AIAgent._cap_delegate_task_calls(tcs) + expected = [delegates[0], t1, delegates[1], w1] + delegates[2:MAX_CONCURRENT_CHILDREN] + assert len(out) == len(expected) + for i, (actual, exp) in enumerate(zip(out, expected)): + assert actual is exp, f"mismatch at index {i}" + + +# --------------------------------------------------------------------------- +# Phase 2b — _deduplicate_tool_calls +# --------------------------------------------------------------------------- + +class TestDeduplicateToolCalls: + + def test_duplicate_pair_deduplicated(self): + tcs = [ + make_tc("web_search", '{"query":"foo"}'), + make_tc("web_search", '{"query":"foo"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert len(out) == 1 + + def test_multiple_duplicates(self): + tcs = [ + make_tc("web_search", '{"q":"a"}'), + make_tc("web_search", '{"q":"a"}'), + make_tc("terminal", '{"cmd":"ls"}'), + make_tc("terminal", '{"cmd":"ls"}'), + make_tc("terminal", '{"cmd":"pwd"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert len(out) == 3 + + def test_same_tool_different_args_kept(self): + tcs = [ + make_tc("terminal", '{"cmd":"ls"}'), + make_tc("terminal", '{"cmd":"pwd"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert out is tcs + + def test_different_tools_same_args_kept(self): + tcs = [ + make_tc("tool_a", '{"x":1}'), + make_tc("tool_b", '{"x":1}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert out is tcs + + def test_clean_list_unchanged(self): + tcs = [ + make_tc("web_search", '{"q":"x"}'), + make_tc("terminal", '{"cmd":"ls"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert out is tcs + + def test_empty_list_safe(self): + assert AIAgent._deduplicate_tool_calls([]) == [] + + def test_first_occurrence_kept(self): + tc1 = make_tc("terminal", '{"cmd":"ls"}') + tc2 = make_tc("terminal", '{"cmd":"ls"}') + out = AIAgent._deduplicate_tool_calls([tc1, tc2]) + assert len(out) == 1 + assert out[0] is tc1 + + def test_original_list_not_mutated(self): + tcs = [ + make_tc("web_search", '{"q":"dup"}'), + make_tc("web_search", '{"q":"dup"}'), + ] + original_len = len(tcs) + AIAgent._deduplicate_tool_calls(tcs) + assert len(tcs) == original_len + + +# --------------------------------------------------------------------------- +# _get_tool_call_id_static +# --------------------------------------------------------------------------- + +class TestGetToolCallIdStatic: + + def test_dict_with_valid_id(self): + assert AIAgent._get_tool_call_id_static({"id": "call_123"}) == "call_123" + + def test_dict_with_none_id(self): + assert AIAgent._get_tool_call_id_static({"id": None}) == "" + + def test_dict_without_id_key(self): + assert AIAgent._get_tool_call_id_static({"function": {}}) == "" + + def test_object_with_valid_id(self): + tc = types.SimpleNamespace(id="call_456") + assert AIAgent._get_tool_call_id_static(tc) == "call_456" + + def test_object_with_none_id(self): + tc = types.SimpleNamespace(id=None) + assert AIAgent._get_tool_call_id_static(tc) == "" + + def test_object_without_id_attr(self): + tc = types.SimpleNamespace() + assert AIAgent._get_tool_call_id_static(tc) == ""