fix(compaction): token-budget primary tail protection
Tail protection was effectively message-count based despite having a token budget, because protect_last_n=20 acted as a hard floor. A single 50K-token tool output would cause all 20 recent messages to be preserved regardless of budget, leaving little room for summarization. Changes: - _find_tail_cut_by_tokens: min_tail reduced from protect_last_n (20) to 3; token budget is now the primary criterion - Soft ceiling at 1.5x budget to avoid cutting mid-oversized-message - _prune_old_tool_results: accepts optional protect_tail_tokens so pruning also respects the token budget instead of a fixed count - compress() minimum message check relaxed from protect_first_n + protect_last_n + 1 to protect_first_n + 3 + 1 - Tool group alignment (no splitting tool_call/result) preserved
This commit is contained in:
@@ -154,12 +154,15 @@ class ContextCompressor:
|
||||
|
||||
def _prune_old_tool_results(
|
||||
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
||||
protect_tail_tokens: int | None = None,
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""Replace old tool result contents with a short placeholder.
|
||||
|
||||
Walks backward from the end, protecting the most recent
|
||||
``protect_tail_count`` messages. Older tool results get their
|
||||
content replaced with a placeholder string.
|
||||
Walks backward from the end, protecting the most recent messages that
|
||||
fall within ``protect_tail_tokens`` (when provided) OR the last
|
||||
``protect_tail_count`` messages (backward-compatible default).
|
||||
When both are given, the token budget takes priority and the message
|
||||
count acts as a hard minimum floor.
|
||||
|
||||
Returns (pruned_messages, pruned_count).
|
||||
"""
|
||||
@@ -168,7 +171,29 @@ class ContextCompressor:
|
||||
|
||||
result = [m.copy() for m in messages]
|
||||
pruned = 0
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
# Determine the prune boundary
|
||||
if protect_tail_tokens is not None and protect_tail_tokens > 0:
|
||||
# Token-budget approach: walk backward accumulating tokens
|
||||
accumulated = 0
|
||||
boundary = len(result)
|
||||
min_protect = min(protect_tail_count, len(result) - 1)
|
||||
for i in range(len(result) - 1, -1, -1):
|
||||
msg = result[i]
|
||||
content_len = len(msg.get("content") or "")
|
||||
msg_tokens = content_len // _CHARS_PER_TOKEN + 10
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
if accumulated + msg_tokens > protect_tail_tokens and (len(result) - i) >= min_protect:
|
||||
boundary = i
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
boundary = i
|
||||
prune_boundary = max(boundary, len(result) - min_protect)
|
||||
else:
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
for i in range(prune_boundary):
|
||||
msg = result[i]
|
||||
@@ -533,13 +558,20 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
derived from ``summary_target_ratio * context_length``, so it
|
||||
scales automatically with the model's context window.
|
||||
|
||||
Never cuts inside a tool_call/result group. Falls back to the old
|
||||
``protect_last_n`` if the budget would protect fewer messages.
|
||||
Token budget is the primary criterion. A hard minimum of 3 messages
|
||||
is always protected, but the budget is allowed to exceed by up to
|
||||
1.5x to avoid cutting inside an oversized message (tool output, file
|
||||
read, etc.). If even the minimum 3 messages exceed 1.5x the budget
|
||||
the cut is placed right after the head so compression still runs.
|
||||
|
||||
Never cuts inside a tool_call/result group.
|
||||
"""
|
||||
if token_budget is None:
|
||||
token_budget = self.tail_token_budget
|
||||
n = len(messages)
|
||||
min_tail = self.protect_last_n
|
||||
# Hard minimum: always keep at least 3 messages in the tail
|
||||
min_tail = min(3, n - head_end - 1) if n - head_end > 1 else 0
|
||||
soft_ceiling = int(token_budget * 1.5)
|
||||
accumulated = 0
|
||||
cut_idx = n # start from beyond the end
|
||||
|
||||
@@ -552,21 +584,21 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
if accumulated + msg_tokens > token_budget and (n - i) >= min_tail:
|
||||
# Stop once we exceed the soft ceiling (unless we haven't hit min_tail yet)
|
||||
if accumulated + msg_tokens > soft_ceiling and (n - i) >= min_tail:
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
cut_idx = i
|
||||
|
||||
# Ensure we protect at least protect_last_n messages
|
||||
# Ensure we protect at least min_tail messages
|
||||
fallback_cut = n - min_tail
|
||||
if cut_idx > fallback_cut:
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# If the token budget would protect everything (small conversations),
|
||||
# fall back to the fixed protect_last_n approach so compression can
|
||||
# still remove middle turns.
|
||||
# force a cut after the head so compression can still remove middle turns.
|
||||
if cut_idx <= head_end:
|
||||
cut_idx = fallback_cut
|
||||
cut_idx = max(fallback_cut, head_end + 1)
|
||||
|
||||
# Align to avoid splitting tool groups
|
||||
cut_idx = self._align_boundary_backward(messages, cut_idx)
|
||||
@@ -591,12 +623,13 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
# Only need head + 3 tail messages minimum (token budget decides the real tail size)
|
||||
_min_for_compress = self.protect_first_n + 3 + 1
|
||||
if n_messages <= _min_for_compress:
|
||||
if not self.quiet_mode:
|
||||
logger.warning(
|
||||
"Cannot compress: only %d messages (need > %d)",
|
||||
n_messages,
|
||||
self.protect_first_n + self.protect_last_n + 1,
|
||||
n_messages, _min_for_compress,
|
||||
)
|
||||
return messages
|
||||
|
||||
@@ -604,7 +637,8 @@ Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
# Phase 1: Prune old tool results (cheap, no LLM call)
|
||||
messages, pruned_count = self._prune_old_tool_results(
|
||||
messages, protect_tail_count=self.protect_last_n * 3,
|
||||
messages, protect_tail_count=self.protect_last_n,
|
||||
protect_tail_tokens=self.tail_token_budget,
|
||||
)
|
||||
if pruned_count and not self.quiet_mode:
|
||||
logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)
|
||||
|
||||
Reference in New Issue
Block a user