From 5da55ea1e32260d90df98265027eb98c7a3765d9 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Sat, 7 Mar 2026 08:08:00 -0800 Subject: [PATCH] fix: sanitize orphaned tool-call/result pairs in message compression Enhance message compression by adding a method to clean up orphaned tool-call and tool-result pairs. This ensures that the API receives well-formed messages, preventing errors related to mismatched IDs. The new functionality includes removing orphaned results and adding stub results for missing calls, improving overall message integrity during compression. --- agent/context_compressor.py | 110 ++++++++++++++++++++++++++++++++++++ run_agent.py | 7 +++ 2 files changed, 117 insertions(+) diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 9c601a1b..798536fb 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -196,10 +196,111 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" logger.debug("Could not build fallback auxiliary client: %s", exc) return None, None + # ------------------------------------------------------------------ + # Tool-call / tool-result pair integrity helpers + # ------------------------------------------------------------------ + + @staticmethod + def _get_tool_call_id(tc) -> str: + """Extract the call ID from a tool_call entry (dict or SimpleNamespace).""" + if isinstance(tc, dict): + return tc.get("id", "") + return getattr(tc, "id", "") or "" + + def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Fix orphaned tool_call / tool_result pairs after compression. + + Two failure modes: + 1. A tool *result* references a call_id whose assistant tool_call was + removed (summarized/truncated). The API rejects this with + "No tool call found for function call output with call_id ...". + 2. An assistant message has tool_calls whose results were dropped. + The API rejects this because every tool_call must be followed by + a tool result with the matching call_id. + + This method removes orphaned results and inserts stub results for + orphaned calls so the message list is always well-formed. + """ + surviving_call_ids: set = set() + for msg in messages: + if msg.get("role") == "assistant": + for tc in msg.get("tool_calls") or []: + cid = self._get_tool_call_id(tc) + if cid: + surviving_call_ids.add(cid) + + result_call_ids: set = set() + for msg in messages: + if msg.get("role") == "tool": + cid = msg.get("tool_call_id") + if cid: + result_call_ids.add(cid) + + # 1. Remove tool results whose call_id has no matching assistant tool_call + orphaned_results = result_call_ids - surviving_call_ids + if orphaned_results: + messages = [ + m for m in messages + if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results) + ] + if not self.quiet_mode: + logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results)) + + # 2. Add stub results for assistant tool_calls whose results were dropped + missing_results = surviving_call_ids - result_call_ids + if missing_results: + patched: List[Dict[str, Any]] = [] + for msg in messages: + patched.append(msg) + if msg.get("role") == "assistant": + for tc in msg.get("tool_calls") or []: + cid = self._get_tool_call_id(tc) + if cid in missing_results: + patched.append({ + "role": "tool", + "content": "[Result from earlier conversation — see context summary above]", + "tool_call_id": cid, + }) + messages = patched + if not self.quiet_mode: + logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results)) + + return messages + + def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int: + """Push a compress-start boundary forward past any orphan tool results. + + If ``messages[idx]`` is a tool result, slide forward until we hit a + non-tool message so we don't start the summarised region mid-group. + """ + while idx < len(messages) and messages[idx].get("role") == "tool": + idx += 1 + return idx + + def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int: + """Pull a compress-end boundary backward to avoid splitting a + tool_call / result group. + + If the message just before ``idx`` is an assistant message with + tool_calls, those tool results will start at ``idx`` and would be + separated from their parent. Move backwards to include the whole + group in the summarised region. + """ + if idx <= 0 or idx >= len(messages): + return idx + prev = messages[idx - 1] + if prev.get("role") == "assistant" and prev.get("tool_calls"): + # The results for this assistant turn sit at idx..idx+k. + # Include the assistant message in the summarised region too. + idx -= 1 + return idx + def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]: """Compress conversation messages by summarizing middle turns. Keeps first N + last N turns, summarizes everything in between. + After compression, orphaned tool_call / tool_result pairs are cleaned + up so the API never receives mismatched IDs. """ n_messages = len(messages) if n_messages <= self.protect_first_n + self.protect_last_n + 1: @@ -212,6 +313,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" if compress_start >= compress_end: return messages + # Adjust boundaries to avoid splitting tool_call/result groups. + compress_start = self._align_boundary_forward(messages, compress_start) + compress_end = self._align_boundary_backward(messages, compress_end) + if compress_start >= compress_end: + return messages + turns_to_summarize = messages[compress_start:compress_end] display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages) @@ -233,6 +340,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" tail = messages[-self.protect_last_n:] kept.extend(m.copy() for m in tail) self.compression_count += 1 + kept = self._sanitize_tool_pairs(kept) if not self.quiet_mode: print(f" ✂️ Truncated: {len(messages)} → {len(kept)} messages (dropped middle turns)") return kept @@ -256,6 +364,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" self.compression_count += 1 + compressed = self._sanitize_tool_pairs(compressed) + if not self.quiet_mode: new_estimate = estimate_messages_tokens_rough(compressed) saved_estimate = display_tokens - new_estimate diff --git a/run_agent.py b/run_agent.py index 0ee89d7d..84ac1b5e 100644 --- a/run_agent.py +++ b/run_agent.py @@ -3059,6 +3059,13 @@ class AIAgent: if self._use_prompt_caching: api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl) + # Safety net: strip orphaned tool results / add stubs for missing + # results before sending to the API. The compressor handles this + # during compression, but orphans can also sneak in from session + # loading or manual message manipulation. + if hasattr(self, 'context_compressor') and self.context_compressor: + api_messages = self.context_compressor._sanitize_tool_pairs(api_messages) + # Calculate approximate request size for logging total_chars = sum(len(str(msg)) for msg in api_messages) approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token