diff --git a/gateway/run.py b/gateway/run.py index 151ffad13..8458bb9d4 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -950,12 +950,12 @@ class GatewayRunner: # repeated truncation/context failures. Detect this early and # compress proactively — before the agent even starts. (#628) # - # IMPORTANT: This pre-check uses a rough char-based estimate - # (~4 chars/token) which significantly overestimates for - # tool-heavy conversations (code/JSON tokenizes at 5-7+ - # chars/token). To avoid premature compression, we apply a - # 1.4x safety factor — the agent's own compression uses actual - # API-reported token counts and handles precise thresholds. + # Token source priority: + # 1. Actual API-reported prompt_tokens from the last turn + # (stored in session_entry.last_prompt_tokens) + # 2. Rough char-based estimate (str(msg)//4) with a 1.4x + # safety factor to account for overestimation on tool-heavy + # conversations (code/JSON tokenizes at 5-7+ chars/token). # ----------------------------------------------------------------- if history and len(history) >= 4: from agent.model_metadata import ( @@ -1003,25 +1003,37 @@ class GatewayRunner: if _hyg_compression_enabled: _hyg_context_length = get_model_context_length(_hyg_model) - # Apply 1.4x safety factor to account for rough estimate - # overestimation on tool-heavy / code-heavy conversations. - _ROUGH_ESTIMATE_SAFETY = 1.4 _compress_token_threshold = int( - _hyg_context_length * _hyg_threshold_pct * _ROUGH_ESTIMATE_SAFETY + _hyg_context_length * _hyg_threshold_pct ) - # Warn if still huge after compression (95% of context, with same safety factor) - _warn_token_threshold = int(_hyg_context_length * 0.95 * _ROUGH_ESTIMATE_SAFETY) + _warn_token_threshold = int(_hyg_context_length * 0.95) _msg_count = len(history) - _approx_tokens = estimate_messages_tokens_rough(history) + + # Prefer actual API-reported tokens from the last turn + # (stored in session entry) over the rough char-based estimate. + # The rough estimate (str(msg)//4) overestimates by 30-50% on + # tool-heavy/code-heavy conversations, causing premature compression. + _stored_tokens = session_entry.last_prompt_tokens + if _stored_tokens > 0: + _approx_tokens = _stored_tokens + _token_source = "actual" + else: + _approx_tokens = estimate_messages_tokens_rough(history) + # Apply safety factor only for rough estimates + _compress_token_threshold = int( + _compress_token_threshold * 1.4 + ) + _warn_token_threshold = int(_warn_token_threshold * 1.4) + _token_source = "estimated" _needs_compress = _approx_tokens >= _compress_token_threshold if _needs_compress: logger.info( - "Session hygiene: %s messages, ~%s tokens — auto-compressing " + "Session hygiene: %s messages, ~%s tokens (%s) — auto-compressing " "(threshold: %s%% of %s = %s tokens)", - _msg_count, f"{_approx_tokens:,}", + _msg_count, f"{_approx_tokens:,}", _token_source, int(_hyg_threshold_pct * 100), f"{_hyg_context_length:,}", f"{_compress_token_threshold:,}", @@ -1344,8 +1356,11 @@ class GatewayRunner: skip_db=agent_persisted, ) - # Update session - self.session_store.update_session(session_entry.session_key) + # Update session with actual prompt token count from the agent + self.session_store.update_session( + session_entry.session_key, + last_prompt_tokens=agent_result.get("last_prompt_tokens", 0), + ) return response @@ -2908,6 +2923,13 @@ class GatewayRunner: # Return final response, or a message if something went wrong final_response = result.get("final_response") + + # Extract last actual prompt token count from the agent's compressor + _last_prompt_toks = 0 + _agent = agent_holder[0] + if _agent and hasattr(_agent, "context_compressor"): + _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0) + if not final_response: error_msg = f"⚠️ {result['error']}" if result.get("error") else "(No response generated)" return { @@ -2916,6 +2938,7 @@ class GatewayRunner: "api_calls": result.get("api_calls", 0), "tools": tools_holder[0] or [], "history_offset": len(agent_history), + "last_prompt_tokens": _last_prompt_toks, } # Scan tool results for MEDIA: tags that need to be delivered @@ -2959,6 +2982,7 @@ class GatewayRunner: "api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0, "tools": tools_holder[0] or [], "history_offset": len(agent_history), + "last_prompt_tokens": _last_prompt_toks, } # Start progress message sender if enabled diff --git a/gateway/session.py b/gateway/session.py index 410d24037..e2777fe1a 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -241,6 +241,9 @@ class SessionEntry: output_tokens: int = 0 total_tokens: int = 0 + # Last API-reported prompt tokens (for accurate compression pre-check) + last_prompt_tokens: int = 0 + # Set when a session was created because the previous one expired; # consumed once by the message handler to inject a notice into context was_auto_reset: bool = False @@ -257,6 +260,7 @@ class SessionEntry: "input_tokens": self.input_tokens, "output_tokens": self.output_tokens, "total_tokens": self.total_tokens, + "last_prompt_tokens": self.last_prompt_tokens, } if self.origin: result["origin"] = self.origin.to_dict() @@ -287,6 +291,7 @@ class SessionEntry: input_tokens=data.get("input_tokens", 0), output_tokens=data.get("output_tokens", 0), total_tokens=data.get("total_tokens", 0), + last_prompt_tokens=data.get("last_prompt_tokens", 0), ) @@ -550,7 +555,8 @@ class SessionStore: self, session_key: str, input_tokens: int = 0, - output_tokens: int = 0 + output_tokens: int = 0, + last_prompt_tokens: int = 0, ) -> None: """Update a session's metadata after an interaction.""" self._ensure_loaded() @@ -560,6 +566,8 @@ class SessionStore: entry.updated_at = datetime.now() entry.input_tokens += input_tokens entry.output_tokens += output_tokens + if last_prompt_tokens > 0: + entry.last_prompt_tokens = last_prompt_tokens entry.total_tokens = entry.input_tokens + entry.output_tokens self._save()