diff --git a/agent/context_compressor.py b/agent/context_compressor.py
index 0cd51b06e..c61cf2c5a 100644
--- a/agent/context_compressor.py
+++ b/agent/context_compressor.py
@@ -154,12 +154,15 @@ class ContextCompressor:
     def _prune_old_tool_results(
         self,
         messages: List[Dict[str, Any]],
         protect_tail_count: int,
+        protect_tail_tokens: int | None = None,
     ) -> tuple[List[Dict[str, Any]], int]:
         """Replace old tool result contents with a short placeholder.
-        Walks backward from the end, protecting the most recent
-        ``protect_tail_count`` messages. Older tool results get their
-        content replaced with a placeholder string.
+        Walks backward from the end, protecting the most recent messages that
+        fall within ``protect_tail_tokens`` (when provided) OR the last
+        ``protect_tail_count`` messages (backward-compatible default).
+        When both are given, the token budget takes priority and the message
+        count acts as a hard minimum floor.
 
         Returns (pruned_messages, pruned_count).
         """
@@ -168,7 +171,29 @@ class ContextCompressor:
 
         result = [m.copy() for m in messages]
         pruned = 0
-        prune_boundary = len(result) - protect_tail_count
+
+        # Determine the prune boundary
+        if protect_tail_tokens is not None and protect_tail_tokens > 0:
+            # Token-budget approach: walk backward accumulating tokens
+            accumulated = 0
+            boundary = len(result)
+            min_protect = min(protect_tail_count, len(result) - 1)
+            for i in range(len(result) - 1, -1, -1):
+                msg = result[i]
+                content_len = len(msg.get("content") or "")
+                msg_tokens = content_len // _CHARS_PER_TOKEN + 10
+                for tc in msg.get("tool_calls") or []:
+                    if isinstance(tc, dict):
+                        args = tc.get("function", {}).get("arguments", "")
+                        msg_tokens += len(args) // _CHARS_PER_TOKEN
+                if accumulated + msg_tokens > protect_tail_tokens and (len(result) - i) >= min_protect:
+                    boundary = i
+                    break
+                accumulated += msg_tokens
+                boundary = i
+            prune_boundary = max(boundary, len(result) - min_protect)
+        else:
+            prune_boundary = len(result) - protect_tail_count
 
         for i in range(prune_boundary):
             msg = result[i]
@@ -533,13 +558,20 @@ Write only the summary body. Do not include any preamble or prefix."""
         derived from ``summary_target_ratio * context_length``, so it scales
         automatically with the model's context window.
 
-        Never cuts inside a tool_call/result group. Falls back to the old
-        ``protect_last_n`` if the budget would protect fewer messages.
+        Token budget is the primary criterion. A hard minimum of 3 messages
+        is always protected, but the budget is allowed to exceed by up to
+        1.5x to avoid cutting inside an oversized message (tool output, file
+        read, etc.). If even the minimum 3 messages exceed 1.5x the budget
+        the cut is placed right after the head so compression still runs.
+
+        Never cuts inside a tool_call/result group.
         """
         if token_budget is None:
             token_budget = self.tail_token_budget
         n = len(messages)
-        min_tail = self.protect_last_n
+        # Hard minimum: always keep at least 3 messages in the tail
+        min_tail = min(3, n - head_end - 1) if n - head_end > 1 else 0
+        soft_ceiling = int(token_budget * 1.5)
         accumulated = 0
         cut_idx = n  # start from beyond the end
 
@@ -552,21 +584,21 @@ Write only the summary body. Do not include any preamble or prefix."""
                 if isinstance(tc, dict):
                     args = tc.get("function", {}).get("arguments", "")
                     msg_tokens += len(args) // _CHARS_PER_TOKEN
-            if accumulated + msg_tokens > token_budget and (n - i) >= min_tail:
+            # Stop once we exceed the soft ceiling (unless we haven't hit min_tail yet)
+            if accumulated + msg_tokens > soft_ceiling and (n - i) >= min_tail:
                 break
             accumulated += msg_tokens
             cut_idx = i
 
-        # Ensure we protect at least protect_last_n messages
+        # Ensure we protect at least min_tail messages
         fallback_cut = n - min_tail
         if cut_idx > fallback_cut:
             cut_idx = fallback_cut
 
         # If the token budget would protect everything (small conversations),
-        # fall back to the fixed protect_last_n approach so compression can
-        # still remove middle turns.
+        # force a cut after the head so compression can still remove middle turns.
         if cut_idx <= head_end:
-            cut_idx = fallback_cut
+            cut_idx = max(fallback_cut, head_end + 1)
 
         # Align to avoid splitting tool groups
         cut_idx = self._align_boundary_backward(messages, cut_idx)
@@ -591,12 +623,13 @@ Write only the summary body. Do not include any preamble or prefix."""
         up so the API never receives mismatched IDs.
         """
         n_messages = len(messages)
-        if n_messages <= self.protect_first_n + self.protect_last_n + 1:
+        # Only need head + 3 tail messages minimum (token budget decides the real tail size)
+        _min_for_compress = self.protect_first_n + 3 + 1
+        if n_messages <= _min_for_compress:
             if not self.quiet_mode:
                 logger.warning(
                     "Cannot compress: only %d messages (need > %d)",
-                    n_messages,
-                    self.protect_first_n + self.protect_last_n + 1,
+                    n_messages, _min_for_compress,
                 )
             return messages
 
@@ -604,7 +637,8 @@
 
         # Phase 1: Prune old tool results (cheap, no LLM call)
         messages, pruned_count = self._prune_old_tool_results(
-            messages, protect_tail_count=self.protect_last_n * 3,
+            messages, protect_tail_count=self.protect_last_n,
+            protect_tail_tokens=self.tail_token_budget,
         )
         if pruned_count and not self.quiet_mode:
             logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)