fix: use actual API token counts for gateway compression pre-check
Root cause of aggressive gateway compression vs CLI:

- CLI: a single AIAgent persists across the conversation and uses the real
  API-reported prompt_tokens for compression decisions — accurate.
- Gateway: each message creates a fresh AIAgent, the token count is
  discarded afterwards, and the next message's pre-check falls back to a
  rough str(msg)//4 estimate, which overestimates by 30-50% on tool-heavy
  conversations.

Fix:

- Add a last_prompt_tokens field to SessionEntry — stores the actual
  API-reported prompt token count from the most recent agent turn.
- After run_conversation(), extract context_compressor.last_prompt_tokens
  and persist it via update_session().
- The gateway pre-check now uses the stored actual tokens when available
  (the same accuracy as the CLI), falling back to the rough estimate with
  a 1.4x safety factor only for the first message of a session.

This makes gateway compression behave identically to CLI compression for
all turns after the first.

Reported by TigerHix.
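For reference, the pre-check's token-source priority reduces to the sketch below. This is a minimal standalone Python sketch, not the patched code itself: the helper name precheck_tokens and the default threshold_pct=0.75 are illustrative (in the gateway, _hyg_threshold_pct comes from config and the rough count comes from estimate_messages_tokens_rough); the 1.4x relaxation on the estimated path mirrors the patch.

```python
def precheck_tokens(session_entry, history, context_length, threshold_pct=0.75):
    """Return (approx_tokens, compress_threshold, source) for the pre-check.

    Sketch only: names and the 0.75 default are illustrative, not the
    gateway's actual config plumbing.
    """
    compress_threshold = int(context_length * threshold_pct)
    if session_entry.last_prompt_tokens > 0:
        # Turn 2+: exact API-reported count persisted after the last turn.
        return session_entry.last_prompt_tokens, compress_threshold, "actual"
    # First message of a session: the ~4 chars/token estimate overshoots
    # 30-50% on tool-heavy history, so relax the threshold by 1.4x.
    rough = sum(len(str(m)) for m in history) // 4
    return rough, int(compress_threshold * 1.4), "estimated"
```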
@@ -950,12 +950,12 @@ class GatewayRunner:
         # repeated truncation/context failures. Detect this early and
         # compress proactively — before the agent even starts. (#628)
         #
-        # IMPORTANT: This pre-check uses a rough char-based estimate
-        # (~4 chars/token) which significantly overestimates for
-        # tool-heavy conversations (code/JSON tokenizes at 5-7+
-        # chars/token). To avoid premature compression, we apply a
-        # 1.4x safety factor — the agent's own compression uses actual
-        # API-reported token counts and handles precise thresholds.
+        # Token source priority:
+        #   1. Actual API-reported prompt_tokens from the last turn
+        #      (stored in session_entry.last_prompt_tokens)
+        #   2. Rough char-based estimate (str(msg)//4) with a 1.4x
+        #      safety factor to account for overestimation on tool-heavy
+        #      conversations (code/JSON tokenizes at 5-7+ chars/token).
         # -----------------------------------------------------------------
         if history and len(history) >= 4:
             from agent.model_metadata import (
@@ -1003,25 +1003,37 @@ class GatewayRunner:

             if _hyg_compression_enabled:
                 _hyg_context_length = get_model_context_length(_hyg_model)
-                # Apply 1.4x safety factor to account for rough estimate
-                # overestimation on tool-heavy / code-heavy conversations.
-                _ROUGH_ESTIMATE_SAFETY = 1.4
                 _compress_token_threshold = int(
-                    _hyg_context_length * _hyg_threshold_pct * _ROUGH_ESTIMATE_SAFETY
+                    _hyg_context_length * _hyg_threshold_pct
                 )
                 # Warn if still huge after compression (95% of context, with same safety factor)
-                _warn_token_threshold = int(_hyg_context_length * 0.95 * _ROUGH_ESTIMATE_SAFETY)
+                _warn_token_threshold = int(_hyg_context_length * 0.95)

                 _msg_count = len(history)
-                _approx_tokens = estimate_messages_tokens_rough(history)

+                # Prefer actual API-reported tokens from the last turn
+                # (stored in session entry) over the rough char-based estimate.
+                # The rough estimate (str(msg)//4) overestimates by 30-50% on
+                # tool-heavy/code-heavy conversations, causing premature compression.
+                _stored_tokens = session_entry.last_prompt_tokens
+                if _stored_tokens > 0:
+                    _approx_tokens = _stored_tokens
+                    _token_source = "actual"
+                else:
+                    _approx_tokens = estimate_messages_tokens_rough(history)
+                    # Apply safety factor only for rough estimates
+                    _compress_token_threshold = int(
+                        _compress_token_threshold * 1.4
+                    )
+                    _warn_token_threshold = int(_warn_token_threshold * 1.4)
+                    _token_source = "estimated"

                 _needs_compress = _approx_tokens >= _compress_token_threshold

                 if _needs_compress:
                     logger.info(
-                        "Session hygiene: %s messages, ~%s tokens — auto-compressing "
+                        "Session hygiene: %s messages, ~%s tokens (%s) — auto-compressing "
                         "(threshold: %s%% of %s = %s tokens)",
-                        _msg_count, f"{_approx_tokens:,}",
+                        _msg_count, f"{_approx_tokens:,}", _token_source,
                         int(_hyg_threshold_pct * 100),
                         f"{_hyg_context_length:,}",
                         f"{_compress_token_threshold:,}",
@@ -1344,8 +1356,11 @@ class GatewayRunner:
             skip_db=agent_persisted,
         )

-        # Update session
-        self.session_store.update_session(session_entry.session_key)
+        # Update session with actual prompt token count from the agent
+        self.session_store.update_session(
+            session_entry.session_key,
+            last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
+        )

         return response
@@ -2908,6 +2923,13 @@ class GatewayRunner:

         # Return final response, or a message if something went wrong
         final_response = result.get("final_response")

+        # Extract last actual prompt token count from the agent's compressor
+        _last_prompt_toks = 0
+        _agent = agent_holder[0]
+        if _agent and hasattr(_agent, "context_compressor"):
+            _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
+
         if not final_response:
             error_msg = f"⚠️ {result['error']}" if result.get("error") else "(No response generated)"
             return {
@@ -2916,6 +2938,7 @@ class GatewayRunner:
                 "api_calls": result.get("api_calls", 0),
                 "tools": tools_holder[0] or [],
                 "history_offset": len(agent_history),
+                "last_prompt_tokens": _last_prompt_toks,
             }

         # Scan tool results for MEDIA:<path> tags that need to be delivered
@@ -2959,6 +2982,7 @@ class GatewayRunner:
             "api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0,
             "tools": tools_holder[0] or [],
             "history_offset": len(agent_history),
+            "last_prompt_tokens": _last_prompt_toks,
         }

         # Start progress message sender if enabled
@@ -241,6 +241,9 @@ class SessionEntry:
     output_tokens: int = 0
     total_tokens: int = 0

+    # Last API-reported prompt tokens (for accurate compression pre-check)
+    last_prompt_tokens: int = 0
+
     # Set when a session was created because the previous one expired;
     # consumed once by the message handler to inject a notice into context
     was_auto_reset: bool = False
@@ -257,6 +260,7 @@ class SessionEntry:
             "input_tokens": self.input_tokens,
             "output_tokens": self.output_tokens,
             "total_tokens": self.total_tokens,
+            "last_prompt_tokens": self.last_prompt_tokens,
         }
         if self.origin:
             result["origin"] = self.origin.to_dict()
@@ -287,6 +291,7 @@ class SessionEntry:
             input_tokens=data.get("input_tokens", 0),
             output_tokens=data.get("output_tokens", 0),
             total_tokens=data.get("total_tokens", 0),
+            last_prompt_tokens=data.get("last_prompt_tokens", 0),
         )
@@ -550,7 +555,8 @@ class SessionStore:
         self,
         session_key: str,
         input_tokens: int = 0,
-        output_tokens: int = 0
+        output_tokens: int = 0,
+        last_prompt_tokens: int = 0,
     ) -> None:
         """Update a session's metadata after an interaction."""
         self._ensure_loaded()
@@ -560,6 +566,8 @@ class SessionStore:
         entry.updated_at = datetime.now()
         entry.input_tokens += input_tokens
         entry.output_tokens += output_tokens
+        if last_prompt_tokens > 0:
+            entry.last_prompt_tokens = last_prompt_tokens
         entry.total_tokens = entry.input_tokens + entry.output_tokens
         self._save()
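Note that the SessionEntry serialization changes above are backward compatible: session files persisted before this commit simply lack the last_prompt_tokens key, from_dict falls back to 0, and the next pre-check takes the rough-estimate path. A minimal sketch of that round-trip, where this SessionEntry is a stripped-down stand-in for the real dataclass:

```python
from dataclasses import dataclass

@dataclass
class SessionEntry:  # stripped-down stand-in for the real dataclass
    total_tokens: int = 0
    last_prompt_tokens: int = 0

    @classmethod
    def from_dict(cls, data: dict) -> "SessionEntry":
        return cls(
            total_tokens=data.get("total_tokens", 0),
            # Key absent in pre-upgrade session files: defaults to 0, which
            # routes the next pre-check to the rough-estimate fallback.
            last_prompt_tokens=data.get("last_prompt_tokens", 0),
        )

old = SessionEntry.from_dict({"total_tokens": 1200})  # pre-upgrade file
new = SessionEntry.from_dict(
    {"total_tokens": 1200, "last_prompt_tokens": 48_500}  # post-upgrade
)
assert old.last_prompt_tokens == 0
assert new.last_prompt_tokens == 48_500
```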