fix: use actual API token counts for gateway compression pre-check

Root cause of aggressive gateway compression vs CLI:
- CLI: single AIAgent persists across conversation, uses real API-reported
  prompt_tokens for compression decisions — accurate
- Gateway: each message creates fresh AIAgent, token count discarded after,
  next message pre-check falls back to rough str(msg)//4 estimate which
  overestimates 30-50% on tool-heavy conversations

Fix:
- Add last_prompt_tokens field to SessionEntry — stores the actual
  API-reported prompt token count from the most recent agent turn
- After run_conversation(), extract context_compressor.last_prompt_tokens
  and persist it via update_session()
- Gateway pre-check now uses stored actual tokens when available (exactly
  the same accuracy as the CLI), falling back to the rough estimate with a
  1.4x safety factor only for the first message of a session

This makes gateway compression behave identically to CLI compression
for all turns after the first. Reported by TigerHix.
This commit is contained in:
teknium1
2026-03-10 23:28:18 -07:00
parent a35c37a2f9
commit 58dbd81f03
2 changed files with 50 additions and 18 deletions

View File

@@ -950,12 +950,12 @@ class GatewayRunner:
# repeated truncation/context failures. Detect this early and
# compress proactively — before the agent even starts. (#628)
#
# IMPORTANT: This pre-check uses a rough char-based estimate
# (~4 chars/token) which significantly overestimates for
# tool-heavy conversations (code/JSON tokenizes at 5-7+
# chars/token). To avoid premature compression, we apply a
# 1.4x safety factor — the agent's own compression uses actual
# API-reported token counts and handles precise thresholds.
# Token source priority:
# 1. Actual API-reported prompt_tokens from the last turn
# (stored in session_entry.last_prompt_tokens)
# 2. Rough char-based estimate (str(msg)//4) with a 1.4x
# safety factor to account for overestimation on tool-heavy
# conversations (code/JSON tokenizes at 5-7+ chars/token).
# -----------------------------------------------------------------
if history and len(history) >= 4:
from agent.model_metadata import (
@@ -1003,25 +1003,37 @@ class GatewayRunner:
if _hyg_compression_enabled:
_hyg_context_length = get_model_context_length(_hyg_model)
# Safety factor for the rough char-based estimate, which overestimates
# on tool-heavy / code-heavy conversations. Applied to the thresholds
# only when falling back to the estimate (see below).
_ROUGH_ESTIMATE_SAFETY = 1.4
_compress_token_threshold = int(
_hyg_context_length * _hyg_threshold_pct * _ROUGH_ESTIMATE_SAFETY
_hyg_context_length * _hyg_threshold_pct
)
# Warn if still huge after compression (95% of context; the estimate
# safety factor, when applicable, is applied conditionally below)
_warn_token_threshold = int(_hyg_context_length * 0.95 * _ROUGH_ESTIMATE_SAFETY)
_warn_token_threshold = int(_hyg_context_length * 0.95)
_msg_count = len(history)
_approx_tokens = estimate_messages_tokens_rough(history)
# Prefer actual API-reported tokens from the last turn
# (stored in session entry) over the rough char-based estimate.
# The rough estimate (str(msg)//4) overestimates by 30-50% on
# tool-heavy/code-heavy conversations, causing premature compression.
_stored_tokens = session_entry.last_prompt_tokens
if _stored_tokens > 0:
_approx_tokens = _stored_tokens
_token_source = "actual"
else:
_approx_tokens = estimate_messages_tokens_rough(history)
# Apply safety factor only for rough estimates
_compress_token_threshold = int(
_compress_token_threshold * 1.4
)
_warn_token_threshold = int(_warn_token_threshold * 1.4)
_token_source = "estimated"
_needs_compress = _approx_tokens >= _compress_token_threshold
if _needs_compress:
logger.info(
"Session hygiene: %s messages, ~%s tokens — auto-compressing "
"Session hygiene: %s messages, ~%s tokens (%s) — auto-compressing "
"(threshold: %s%% of %s = %s tokens)",
_msg_count, f"{_approx_tokens:,}",
_msg_count, f"{_approx_tokens:,}", _token_source,
int(_hyg_threshold_pct * 100),
f"{_hyg_context_length:,}",
f"{_compress_token_threshold:,}",
@@ -1344,8 +1356,11 @@ class GatewayRunner:
skip_db=agent_persisted,
)
# Update session
self.session_store.update_session(session_entry.session_key)
# Update session with actual prompt token count from the agent
self.session_store.update_session(
session_entry.session_key,
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
)
return response
@@ -2908,6 +2923,13 @@ class GatewayRunner:
# Return final response, or a message if something went wrong
final_response = result.get("final_response")
# Extract last actual prompt token count from the agent's compressor
_last_prompt_toks = 0
_agent = agent_holder[0]
if _agent and hasattr(_agent, "context_compressor"):
_last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
if not final_response:
error_msg = f"⚠️ {result['error']}" if result.get("error") else "(No response generated)"
return {
@@ -2916,6 +2938,7 @@ class GatewayRunner:
"api_calls": result.get("api_calls", 0),
"tools": tools_holder[0] or [],
"history_offset": len(agent_history),
"last_prompt_tokens": _last_prompt_toks,
}
# Scan tool results for MEDIA:<path> tags that need to be delivered
@@ -2959,6 +2982,7 @@ class GatewayRunner:
"api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0,
"tools": tools_holder[0] or [],
"history_offset": len(agent_history),
"last_prompt_tokens": _last_prompt_toks,
}
# Start progress message sender if enabled

View File

@@ -241,6 +241,9 @@ class SessionEntry:
output_tokens: int = 0
total_tokens: int = 0
# Last API-reported prompt tokens (for accurate compression pre-check)
last_prompt_tokens: int = 0
# Set when a session was created because the previous one expired;
# consumed once by the message handler to inject a notice into context
was_auto_reset: bool = False
@@ -257,6 +260,7 @@ class SessionEntry:
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"total_tokens": self.total_tokens,
"last_prompt_tokens": self.last_prompt_tokens,
}
if self.origin:
result["origin"] = self.origin.to_dict()
@@ -287,6 +291,7 @@ class SessionEntry:
input_tokens=data.get("input_tokens", 0),
output_tokens=data.get("output_tokens", 0),
total_tokens=data.get("total_tokens", 0),
last_prompt_tokens=data.get("last_prompt_tokens", 0),
)
@@ -550,7 +555,8 @@ class SessionStore:
self,
session_key: str,
input_tokens: int = 0,
output_tokens: int = 0
output_tokens: int = 0,
last_prompt_tokens: int = 0,
) -> None:
"""Update a session's metadata after an interaction."""
self._ensure_loaded()
@@ -560,6 +566,8 @@ class SessionStore:
entry.updated_at = datetime.now()
entry.input_tokens += input_tokens
entry.output_tokens += output_tokens
if last_prompt_tokens > 0:
entry.last_prompt_tokens = last_prompt_tokens
entry.total_tokens = entry.input_tokens + entry.output_tokens
self._save()